forked from 626_privacy/tensorflow_privacy
Changes to make Tensorflow Privacy compatible with TF 2.0.
PiperOrigin-RevId: 277561553
This commit is contained in:
parent
8a80c1a745
commit
d69879d360
23 changed files with 176 additions and 261 deletions
|
@ -19,18 +19,12 @@ from __future__ import print_function
|
|||
|
||||
import collections
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow_privacy.privacy.analysis import tensor_buffer
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
nest = tf.contrib.framework.nest
|
||||
else:
|
||||
nest = tf.nest
|
||||
|
||||
SampleEntry = collections.namedtuple( # pylint: disable=invalid-name
|
||||
'SampleEntry', ['population_size', 'selection_probability', 'queries'])
|
||||
|
||||
|
@ -83,7 +77,7 @@ class PrivacyLedger(object):
|
|||
if tf.executing_eagerly():
|
||||
if tf.equal(selection_probability, 0):
|
||||
raise ValueError('Selection probability cannot be 0.')
|
||||
init_capacity = tf.cast(tf.ceil(1 / selection_probability), tf.int32)
|
||||
init_capacity = tf.cast(tf.math.ceil(1 / selection_probability), tf.int32)
|
||||
else:
|
||||
if selection_probability == 0:
|
||||
raise ValueError('Selection probability cannot be 0.')
|
||||
|
@ -102,12 +96,7 @@ class PrivacyLedger(object):
|
|||
initial_value=0.0, trainable=False, name='sample_count')
|
||||
self._query_count = tf.Variable(
|
||||
initial_value=0.0, trainable=False, name='query_count')
|
||||
try:
|
||||
# Newer versions of TF
|
||||
self._cs = tf.CriticalSection()
|
||||
except AttributeError:
|
||||
# Older versions of TF
|
||||
self._cs = tf.contrib.framework.CriticalSection()
|
||||
self._cs = tf.CriticalSection()
|
||||
|
||||
def record_sum_query(self, l2_norm_bound, noise_stddev):
|
||||
"""Records that a query was issued.
|
||||
|
@ -122,7 +111,7 @@ class PrivacyLedger(object):
|
|||
|
||||
def _do_record_query():
|
||||
with tf.control_dependencies(
|
||||
[tf.assign(self._query_count, self._query_count + 1)]):
|
||||
[tf.compat.v1.assign(self._query_count, self._query_count + 1)]):
|
||||
return self._query_buffer.append(
|
||||
[self._sample_count, l2_norm_bound, noise_stddev])
|
||||
|
||||
|
@ -131,14 +120,14 @@ class PrivacyLedger(object):
|
|||
def finalize_sample(self):
|
||||
"""Finalizes sample and records sample ledger entry."""
|
||||
with tf.control_dependencies([
|
||||
tf.assign(self._sample_var, [
|
||||
tf.compat.v1.assign(self._sample_var, [
|
||||
self._population_size, self._selection_probability,
|
||||
self._query_count
|
||||
])
|
||||
]):
|
||||
with tf.control_dependencies([
|
||||
tf.assign(self._sample_count, self._sample_count + 1),
|
||||
tf.assign(self._query_count, 0)
|
||||
tf.compat.v1.assign(self._sample_count, self._sample_count + 1),
|
||||
tf.compat.v1.assign(self._query_count, 0)
|
||||
]):
|
||||
return self._sample_buffer.append(self._sample_var)
|
||||
|
||||
|
@ -246,12 +235,12 @@ class QueryWithLedger(dp_query.DPQuery):
|
|||
def get_noised_result(self, sample_state, global_state):
|
||||
"""Ensures sample is recorded to the ledger and returns noised result."""
|
||||
# Ensure sample_state is fully aggregated before calling get_noised_result.
|
||||
with tf.control_dependencies(nest.flatten(sample_state)):
|
||||
with tf.control_dependencies(tf.nest.flatten(sample_state)):
|
||||
result, new_global_state = self._query.get_noised_result(
|
||||
sample_state, global_state)
|
||||
# Ensure inner queries have recorded before finalizing.
|
||||
with tf.control_dependencies(nest.flatten(result)):
|
||||
with tf.control_dependencies(tf.nest.flatten(result)):
|
||||
finalize = self._ledger.finalize_sample()
|
||||
# Ensure finalizing happens.
|
||||
with tf.control_dependencies([finalize]):
|
||||
return nest.map_structure(tf.identity, result), new_global_state
|
||||
return tf.nest.map_structure(tf.identity, result), new_global_state
|
||||
|
|
|
@ -25,7 +25,7 @@ from tensorflow_privacy.privacy.dp_query import gaussian_query
|
|||
from tensorflow_privacy.privacy.dp_query import nested_query
|
||||
from tensorflow_privacy.privacy.dp_query import test_utils
|
||||
|
||||
tf.enable_eager_execution()
|
||||
tf.compat.v1.enable_eager_execution()
|
||||
|
||||
|
||||
class PrivacyLedgerTest(tf.test.TestCase):
|
||||
|
@ -63,8 +63,8 @@ class PrivacyLedgerTest(tf.test.TestCase):
|
|||
query, population_size, selection_probability)
|
||||
|
||||
# First sample.
|
||||
tf.assign(population_size, 10)
|
||||
tf.assign(selection_probability, 0.1)
|
||||
tf.compat.v1.assign(population_size, 10)
|
||||
tf.compat.v1.assign(selection_probability, 0.1)
|
||||
test_utils.run_query(query, [record1, record2])
|
||||
|
||||
expected_queries = [[10.0, 0.0]]
|
||||
|
@ -75,8 +75,8 @@ class PrivacyLedgerTest(tf.test.TestCase):
|
|||
self.assertAllClose(sample_1.queries, expected_queries)
|
||||
|
||||
# Second sample.
|
||||
tf.assign(population_size, 20)
|
||||
tf.assign(selection_probability, 0.2)
|
||||
tf.compat.v1.assign(population_size, 20)
|
||||
tf.compat.v1.assign(selection_probability, 0.2)
|
||||
test_utils.run_query(query, [record1, record2])
|
||||
|
||||
formatted = query.ledger.get_formatted_ledger_eager()
|
||||
|
@ -106,8 +106,8 @@ class PrivacyLedgerTest(tf.test.TestCase):
|
|||
record2 = [5.0, [1.0, 2.0]]
|
||||
|
||||
# First sample.
|
||||
tf.assign(population_size, 10)
|
||||
tf.assign(selection_probability, 0.1)
|
||||
tf.compat.v1.assign(population_size, 10)
|
||||
tf.compat.v1.assign(selection_probability, 0.1)
|
||||
test_utils.run_query(query, [record1, record2])
|
||||
|
||||
expected_queries = [[4.0, 2.0], [5.0, 1.0]]
|
||||
|
@ -118,8 +118,8 @@ class PrivacyLedgerTest(tf.test.TestCase):
|
|||
self.assertAllClose(sorted(sample_1.queries), sorted(expected_queries))
|
||||
|
||||
# Second sample.
|
||||
tf.assign(population_size, 20)
|
||||
tf.assign(selection_probability, 0.2)
|
||||
tf.compat.v1.assign(population_size, 20)
|
||||
tf.compat.v1.assign(selection_probability, 0.2)
|
||||
test_utils.run_query(query, [record1, record2])
|
||||
|
||||
formatted = query.ledger.get_formatted_ledger_eager()
|
||||
|
|
|
@ -50,10 +50,10 @@ class TensorBuffer(object):
|
|||
raise ValueError('Shape cannot be scalar.')
|
||||
shape = [capacity] + shape
|
||||
|
||||
with tf.variable_scope(self._name):
|
||||
with tf.compat.v1.variable_scope(self._name):
|
||||
# We need to use a placeholder as the initial value to allow resizing.
|
||||
self._buffer = tf.Variable(
|
||||
initial_value=tf.placeholder_with_default(
|
||||
self._buffer = tf.compat.v1.Variable(
|
||||
initial_value=tf.compat.v1.placeholder_with_default(
|
||||
tf.zeros(shape, dtype), shape=None),
|
||||
trainable=False,
|
||||
name='buffer',
|
||||
|
@ -82,38 +82,39 @@ class TensorBuffer(object):
|
|||
padding = tf.zeros_like(self._buffer, self._buffer.dtype)
|
||||
new_buffer = tf.concat([self._buffer, padding], axis=0)
|
||||
if tf.executing_eagerly():
|
||||
with tf.variable_scope(self._name, reuse=True):
|
||||
self._buffer = tf.get_variable(
|
||||
with tf.compat.v1.variable_scope(self._name, reuse=True):
|
||||
self._buffer = tf.compat.v1.get_variable(
|
||||
name='buffer',
|
||||
dtype=self._dtype,
|
||||
initializer=new_buffer,
|
||||
trainable=False)
|
||||
return self._buffer, tf.assign(self._capacity,
|
||||
tf.multiply(self._capacity, 2))
|
||||
return self._buffer, tf.compat.v1.assign(
|
||||
self._capacity, tf.multiply(self._capacity, 2))
|
||||
else:
|
||||
return tf.assign(
|
||||
return tf.compat.v1.assign(
|
||||
self._buffer, new_buffer,
|
||||
validate_shape=False), tf.assign(self._capacity,
|
||||
tf.multiply(self._capacity, 2))
|
||||
validate_shape=False), tf.compat.v1.assign(
|
||||
self._capacity, tf.multiply(self._capacity, 2))
|
||||
|
||||
update_buffer, update_capacity = tf.cond(
|
||||
tf.equal(self._current_size, self._capacity),
|
||||
_double_capacity, lambda: (self._buffer, self._capacity))
|
||||
pred=tf.equal(self._current_size, self._capacity),
|
||||
true_fn=_double_capacity,
|
||||
false_fn=lambda: (self._buffer, self._capacity))
|
||||
|
||||
with tf.control_dependencies([update_buffer, update_capacity]):
|
||||
with tf.control_dependencies([
|
||||
tf.assert_less(
|
||||
tf.compat.v1.assert_less(
|
||||
self._current_size,
|
||||
self._capacity,
|
||||
message='Appending past end of TensorBuffer.'),
|
||||
tf.assert_equal(
|
||||
tf.shape(value),
|
||||
tf.shape(self._buffer)[1:],
|
||||
tf.compat.v1.assert_equal(
|
||||
tf.shape(input=value),
|
||||
tf.shape(input=self._buffer)[1:],
|
||||
message='Appending value of inconsistent shape.')
|
||||
]):
|
||||
with tf.control_dependencies(
|
||||
[tf.assign(self._buffer[self._current_size, :], value)]):
|
||||
return tf.assign_add(self._current_size, 1)
|
||||
[tf.compat.v1.assign(self._buffer[self._current_size, :], value)]):
|
||||
return tf.compat.v1.assign_add(self._current_size, 1)
|
||||
|
||||
@property
|
||||
def values(self):
|
||||
|
|
|
@ -21,7 +21,7 @@ import tensorflow as tf
|
|||
|
||||
from tensorflow_privacy.privacy.analysis import tensor_buffer
|
||||
|
||||
tf.enable_eager_execution()
|
||||
tf.compat.v1.enable_eager_execution()
|
||||
|
||||
|
||||
class TensorBufferTest(tf.test.TestCase):
|
||||
|
|
|
@ -38,7 +38,7 @@ class TensorBufferTest(tf.test.TestCase):
|
|||
values = my_buffer.values
|
||||
current_size = my_buffer.current_size
|
||||
capacity = my_buffer.capacity
|
||||
self.evaluate(tf.global_variables_initializer())
|
||||
self.evaluate(tf.compat.v1.global_variables_initializer())
|
||||
|
||||
v, cs, cap = sess.run([values, current_size, capacity])
|
||||
self.assertAllEqual(v, [value1, value2])
|
||||
|
@ -60,7 +60,7 @@ class TensorBufferTest(tf.test.TestCase):
|
|||
values = my_buffer.values
|
||||
current_size = my_buffer.current_size
|
||||
capacity = my_buffer.capacity
|
||||
self.evaluate(tf.global_variables_initializer())
|
||||
self.evaluate(tf.compat.v1.global_variables_initializer())
|
||||
|
||||
v, cs, cap = sess.run([values, current_size, capacity])
|
||||
self.assertAllEqual(v, [value1, value2, value3])
|
||||
|
|
|
@ -47,13 +47,8 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import tensorflow as tf
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
nest = tf.contrib.framework.nest
|
||||
else:
|
||||
nest = tf.nest
|
||||
|
||||
|
||||
class DPQuery(object):
|
||||
|
@ -206,7 +201,7 @@ class DPQuery(object):
|
|||
def zeros_like(arg):
|
||||
"""A `zeros_like` function that also works for `tf.TensorSpec`s."""
|
||||
try:
|
||||
arg = tf.convert_to_tensor(arg)
|
||||
arg = tf.convert_to_tensor(value=arg)
|
||||
except TypeError:
|
||||
pass
|
||||
return tf.zeros(arg.shape, arg.dtype)
|
||||
|
@ -216,10 +211,10 @@ class SumAggregationDPQuery(DPQuery):
|
|||
"""Base class for DPQueries that aggregate via sum."""
|
||||
|
||||
def initial_sample_state(self, template):
|
||||
return nest.map_structure(zeros_like, template)
|
||||
return tf.nest.map_structure(zeros_like, template)
|
||||
|
||||
def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
|
||||
return nest.map_structure(tf.add, sample_state, preprocessed_record)
|
||||
return tf.nest.map_structure(tf.add, sample_state, preprocessed_record)
|
||||
|
||||
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||
return nest.map_structure(tf.add, sample_state_1, sample_state_2)
|
||||
return tf.nest.map_structure(tf.add, sample_state_1, sample_state_2)
|
||||
|
|
|
@ -27,11 +27,6 @@ import tensorflow as tf
|
|||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
from tensorflow_privacy.privacy.dp_query import normalized_query
|
||||
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
nest = tf.contrib.framework.nest
|
||||
else:
|
||||
nest = tf.nest
|
||||
|
||||
|
||||
class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
||||
"""Implements DPQuery interface for Gaussian sum queries.
|
||||
|
@ -70,7 +65,7 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
|||
return global_state.l2_norm_clip
|
||||
|
||||
def initial_sample_state(self, template):
|
||||
return nest.map_structure(
|
||||
return tf.nest.map_structure(
|
||||
dp_query.zeros_like, template)
|
||||
|
||||
def preprocess_record_impl(self, params, record):
|
||||
|
@ -86,9 +81,9 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
|||
before clipping.
|
||||
"""
|
||||
l2_norm_clip = params
|
||||
record_as_list = nest.flatten(record)
|
||||
record_as_list = tf.nest.flatten(record)
|
||||
clipped_as_list, norm = tf.clip_by_global_norm(record_as_list, l2_norm_clip)
|
||||
return nest.pack_sequence_as(record, clipped_as_list), norm
|
||||
return tf.nest.pack_sequence_as(record, clipped_as_list), norm
|
||||
|
||||
def preprocess_record(self, params, record):
|
||||
preprocessed_record, _ = self.preprocess_record_impl(params, record)
|
||||
|
@ -98,11 +93,14 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
|||
"""See base class."""
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
def add_noise(v):
|
||||
return v + tf.random_normal(tf.shape(v), stddev=global_state.stddev)
|
||||
return v + tf.random.normal(
|
||||
tf.shape(input=v), stddev=global_state.stddev)
|
||||
else:
|
||||
random_normal = tf.random_normal_initializer(stddev=global_state.stddev)
|
||||
random_normal = tf.compat.v1.random_normal_initializer(
|
||||
stddev=global_state.stddev)
|
||||
|
||||
def add_noise(v):
|
||||
return v + random_normal(tf.shape(v))
|
||||
return v + random_normal(tf.shape(input=v))
|
||||
|
||||
if self._ledger:
|
||||
dependencies = [
|
||||
|
@ -112,7 +110,7 @@ class GaussianSumQuery(dp_query.SumAggregationDPQuery):
|
|||
else:
|
||||
dependencies = []
|
||||
with tf.control_dependencies(dependencies):
|
||||
return nest.map_structure(add_noise, sample_state), global_state
|
||||
return tf.nest.map_structure(add_noise, sample_state), global_state
|
||||
|
||||
|
||||
class GaussianAverageQuery(normalized_query.NormalizedQuery):
|
||||
|
|
|
@ -59,13 +59,14 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
record2 = tf.constant([4.0, -3.0]) # Not clipped.
|
||||
|
||||
l2_norm_clip = tf.Variable(5.0)
|
||||
l2_norm_clip_placeholder = tf.placeholder(tf.float32)
|
||||
assign_l2_norm_clip = tf.assign(l2_norm_clip, l2_norm_clip_placeholder)
|
||||
l2_norm_clip_placeholder = tf.compat.v1.placeholder(tf.float32)
|
||||
assign_l2_norm_clip = tf.compat.v1.assign(l2_norm_clip,
|
||||
l2_norm_clip_placeholder)
|
||||
query = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip=l2_norm_clip, stddev=0.0)
|
||||
query_result, _ = test_utils.run_query(query, [record1, record2])
|
||||
|
||||
self.evaluate(tf.global_variables_initializer())
|
||||
self.evaluate(tf.compat.v1.global_variables_initializer())
|
||||
result = sess.run(query_result)
|
||||
expected = [1.0, 1.0]
|
||||
self.assertAllClose(result, expected)
|
||||
|
|
|
@ -19,16 +19,10 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
nest = tf.contrib.framework.nest
|
||||
else:
|
||||
nest = tf.nest
|
||||
|
||||
|
||||
class NestedQuery(dp_query.DPQuery):
|
||||
"""Implements DPQuery interface for structured queries.
|
||||
|
@ -59,7 +53,8 @@ class NestedQuery(dp_query.DPQuery):
|
|||
def _map_to_queries(self, fn, *inputs, **kwargs):
|
||||
def caller(query, *args):
|
||||
return getattr(query, fn)(*args, **kwargs)
|
||||
return nest.map_structure_up_to(
|
||||
|
||||
return tf.contrib.framework.nest.map_structure_up_to(
|
||||
self._queries, caller, self._queries, *inputs)
|
||||
|
||||
def set_ledger(self, ledger):
|
||||
|
@ -110,7 +105,7 @@ class NestedQuery(dp_query.DPQuery):
|
|||
'get_noised_result', sample_state, global_state)
|
||||
|
||||
flat_estimates, flat_new_global_states = zip(
|
||||
*nest.flatten_up_to(self._queries, estimates_and_new_global_states))
|
||||
return (
|
||||
nest.pack_sequence_as(self._queries, flat_estimates),
|
||||
nest.pack_sequence_as(self._queries, flat_new_global_states))
|
||||
*tf.contrib.framework.nest.flatten_up_to(
|
||||
self._queries, estimates_and_new_global_states))
|
||||
return (tf.nest.pack_sequence_as(self._queries, flat_estimates),
|
||||
tf.nest.pack_sequence_as(self._queries, flat_new_global_states))
|
||||
|
|
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
|||
|
||||
|
||||
from absl.testing import parameterized
|
||||
from distutils.version import LooseVersion
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
@ -28,10 +27,6 @@ from tensorflow_privacy.privacy.dp_query import gaussian_query
|
|||
from tensorflow_privacy.privacy.dp_query import nested_query
|
||||
from tensorflow_privacy.privacy.dp_query import test_utils
|
||||
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
nest = tf.contrib.framework.nest
|
||||
else:
|
||||
nest = tf.nest
|
||||
|
||||
_basic_query = gaussian_query.GaussianSumQuery(1.0, 0.0)
|
||||
|
||||
|
@ -127,7 +122,7 @@ class NestedQueryTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
noised_averages = []
|
||||
for _ in range(1000):
|
||||
noised_averages.append(nest.flatten(sess.run(query_result)))
|
||||
noised_averages.append(tf.nest.flatten(sess.run(query_result)))
|
||||
|
||||
result_stddev = np.std(noised_averages, 0)
|
||||
avg_stddev = sum_stddev / denominator
|
||||
|
|
|
@ -17,16 +17,10 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
nest = tf.contrib.framework.nest
|
||||
else:
|
||||
nest = tf.nest
|
||||
|
||||
|
||||
class NoPrivacySumQuery(dp_query.SumAggregationDPQuery):
|
||||
"""Implements DPQuery interface for a sum query with no privacy.
|
||||
|
@ -52,12 +46,12 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
|
|||
|
||||
def preprocess_record(self, params, record, weight=1):
|
||||
"""Multiplies record by weight."""
|
||||
weighted_record = nest.map_structure(lambda t: weight * t, record)
|
||||
weighted_record = tf.nest.map_structure(lambda t: weight * t, record)
|
||||
return (weighted_record, tf.cast(weight, tf.float32))
|
||||
|
||||
def accumulate_record(self, params, sample_state, record, weight=1):
|
||||
"""Accumulates record, multiplying by weight."""
|
||||
weighted_record = nest.map_structure(lambda t: weight * t, record)
|
||||
weighted_record = tf.nest.map_structure(lambda t: weight * t, record)
|
||||
return self.accumulate_preprocessed_record(
|
||||
sample_state, (weighted_record, tf.cast(weight, tf.float32)))
|
||||
|
||||
|
@ -65,6 +59,5 @@ class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
|
|||
"""See base class."""
|
||||
sum_state, denominator = sample_state
|
||||
|
||||
return (
|
||||
nest.map_structure(lambda t: t / denominator, sum_state),
|
||||
global_state)
|
||||
return (tf.nest.map_structure(lambda t: t / denominator,
|
||||
sum_state), global_state)
|
||||
|
|
|
@ -21,16 +21,10 @@ from __future__ import print_function
|
|||
|
||||
import collections
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow_privacy.privacy.dp_query import dp_query
|
||||
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
nest = tf.contrib.framework.nest
|
||||
else:
|
||||
nest = tf.nest
|
||||
|
||||
|
||||
class NormalizedQuery(dp_query.DPQuery):
|
||||
"""DPQuery for queries with a DPQuery numerator and fixed denominator."""
|
||||
|
@ -89,7 +83,7 @@ class NormalizedQuery(dp_query.DPQuery):
|
|||
def normalize(v):
|
||||
return tf.truediv(v, global_state.denominator)
|
||||
|
||||
return (nest.map_structure(normalize, noised_sum),
|
||||
return (tf.nest.map_structure(normalize, noised_sum),
|
||||
self._GlobalState(new_sum_global_state, global_state.denominator))
|
||||
|
||||
def merge_sample_states(self, sample_state_1, sample_state_2):
|
||||
|
|
|
@ -26,7 +26,6 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
@ -34,11 +33,6 @@ from tensorflow_privacy.privacy.dp_query import dp_query
|
|||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||
from tensorflow_privacy.privacy.dp_query import normalized_query
|
||||
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
nest = tf.contrib.framework.nest
|
||||
else:
|
||||
nest = tf.nest
|
||||
|
||||
|
||||
class QuantileAdaptiveClipSumQuery(dp_query.DPQuery):
|
||||
"""DPQuery for sum queries with adaptive clipping.
|
||||
|
|
|
@ -25,7 +25,7 @@ from tensorflow_privacy.privacy.analysis import privacy_ledger
|
|||
from tensorflow_privacy.privacy.dp_query import quantile_adaptive_clip_sum_query
|
||||
from tensorflow_privacy.privacy.dp_query import test_utils
|
||||
|
||||
tf.enable_eager_execution()
|
||||
tf.compat.v1.enable_eager_execution()
|
||||
|
||||
|
||||
class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase):
|
||||
|
@ -211,7 +211,7 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase):
|
|||
global_state = query.initial_global_state()
|
||||
|
||||
for t in range(50):
|
||||
tf.assign(learning_rate, 1.0 / np.sqrt(t+1))
|
||||
tf.compat.v1.assign(learning_rate, 1.0 / np.sqrt(t + 1))
|
||||
_, global_state = test_utils.run_query(query, records, global_state)
|
||||
|
||||
actual_clip = global_state.l2_norm_clip
|
||||
|
@ -237,7 +237,7 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase):
|
|||
global_state = query.initial_global_state()
|
||||
|
||||
for t in range(50):
|
||||
tf.assign(learning_rate, 1.0 / np.sqrt(t+1))
|
||||
tf.compat.v1.assign(learning_rate, 1.0 / np.sqrt(t + 1))
|
||||
_, global_state = test_utils.run_query(query, records, global_state)
|
||||
|
||||
actual_clip = global_state.l2_norm_clip
|
||||
|
@ -264,8 +264,8 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase):
|
|||
query, population_size, selection_probability)
|
||||
|
||||
# First sample.
|
||||
tf.assign(population_size, 10)
|
||||
tf.assign(selection_probability, 0.1)
|
||||
tf.compat.v1.assign(population_size, 10)
|
||||
tf.compat.v1.assign(selection_probability, 0.1)
|
||||
_, global_state = test_utils.run_query(query, [record1, record2])
|
||||
|
||||
expected_queries = [[10.0, 10.0], [0.5, 0.0]]
|
||||
|
@ -276,8 +276,8 @@ class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase):
|
|||
self.assertAllClose(sample_1.queries, expected_queries)
|
||||
|
||||
# Second sample.
|
||||
tf.assign(population_size, 20)
|
||||
tf.assign(selection_probability, 0.2)
|
||||
tf.compat.v1.assign(population_size, 20)
|
||||
tf.compat.v1.assign(selection_probability, 0.2)
|
||||
test_utils.run_query(query, [record1, record2], global_state)
|
||||
|
||||
formatted = query.ledger.get_formatted_ledger_eager()
|
||||
|
|
|
@ -17,30 +17,21 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
from absl import logging
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow_privacy.privacy.analysis import privacy_ledger
|
||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
nest = tf.contrib.framework.nest
|
||||
else:
|
||||
nest = tf.nest
|
||||
|
||||
|
||||
def make_optimizer_class(cls):
|
||||
"""Constructs a DP optimizer class from an existing one."""
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
parent_code = tf.train.Optimizer.compute_gradients.__code__
|
||||
child_code = cls.compute_gradients.__code__
|
||||
GATE_OP = tf.train.Optimizer.GATE_OP # pylint: disable=invalid-name
|
||||
else:
|
||||
parent_code = tf.optimizers.Optimizer._compute_gradients.__code__ # pylint: disable=protected-access
|
||||
child_code = cls._compute_gradients.__code__ # pylint: disable=protected-access
|
||||
GATE_OP = None # pylint: disable=invalid-name
|
||||
parent_code = tf.compat.v1.train.Optimizer.compute_gradients.__code__
|
||||
child_code = cls.compute_gradients.__code__
|
||||
GATE_OP = tf.compat.v1.train.Optimizer.GATE_OP # pylint: disable=invalid-name
|
||||
if child_code is not parent_code:
|
||||
tf.logging.warning(
|
||||
logging.warning(
|
||||
'WARNING: Calling make_optimizer_class() on class %s that overrides '
|
||||
'method compute_gradients(). Check to ensure that '
|
||||
'make_optimizer_class() does not interfere with overridden version.',
|
||||
|
@ -92,7 +83,7 @@ def make_optimizer_class(cls):
|
|||
|
||||
vector_loss = loss()
|
||||
if self._num_microbatches is None:
|
||||
self._num_microbatches = tf.shape(vector_loss)[0]
|
||||
self._num_microbatches = tf.shape(input=vector_loss)[0]
|
||||
sample_state = self._dp_sum_query.initial_sample_state(var_list)
|
||||
microbatches_losses = tf.reshape(vector_loss,
|
||||
[self._num_microbatches, -1])
|
||||
|
@ -101,7 +92,8 @@ def make_optimizer_class(cls):
|
|||
|
||||
def process_microbatch(i, sample_state):
|
||||
"""Process one microbatch (record) with privacy helper."""
|
||||
microbatch_loss = tf.reduce_mean(tf.gather(microbatches_losses, [i]))
|
||||
microbatch_loss = tf.reduce_mean(
|
||||
input_tensor=tf.gather(microbatches_losses, [i]))
|
||||
grads = gradient_tape.gradient(microbatch_loss, var_list)
|
||||
sample_state = self._dp_sum_query.accumulate_record(
|
||||
sample_params, sample_state, grads)
|
||||
|
@ -117,7 +109,7 @@ def make_optimizer_class(cls):
|
|||
def normalize(v):
|
||||
return v / tf.cast(self._num_microbatches, tf.float32)
|
||||
|
||||
final_grads = nest.map_structure(normalize, grad_sums)
|
||||
final_grads = tf.nest.map_structure(normalize, grad_sums)
|
||||
|
||||
grads_and_vars = list(zip(final_grads, var_list))
|
||||
return grads_and_vars
|
||||
|
@ -132,7 +124,7 @@ def make_optimizer_class(cls):
|
|||
# although that still wouldn't be quite correct because it would be
|
||||
# sampling from the dataset without replacement.
|
||||
if self._num_microbatches is None:
|
||||
self._num_microbatches = tf.shape(loss)[0]
|
||||
self._num_microbatches = tf.shape(input=loss)[0]
|
||||
|
||||
microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1])
|
||||
sample_params = (
|
||||
|
@ -141,8 +133,8 @@ def make_optimizer_class(cls):
|
|||
def process_microbatch(i, sample_state):
|
||||
"""Process one microbatch (record) with privacy helper."""
|
||||
grads, _ = zip(*super(cls, self).compute_gradients(
|
||||
tf.reduce_mean(tf.gather(microbatches_losses,
|
||||
[i])), var_list, gate_gradients,
|
||||
tf.reduce_mean(input_tensor=tf.gather(
|
||||
microbatches_losses, [i])), var_list, gate_gradients,
|
||||
aggregation_method, colocate_gradients_with_ops, grad_loss))
|
||||
grads_list = [
|
||||
g if g is not None else tf.zeros_like(v)
|
||||
|
@ -154,8 +146,8 @@ def make_optimizer_class(cls):
|
|||
|
||||
if var_list is None:
|
||||
var_list = (
|
||||
tf.trainable_variables() + tf.get_collection(
|
||||
tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
|
||||
tf.compat.v1.trainable_variables() + tf.compat.v1.get_collection(
|
||||
tf.compat.v1.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
|
||||
|
||||
sample_state = self._dp_sum_query.initial_sample_state(var_list)
|
||||
|
||||
|
@ -169,7 +161,8 @@ def make_optimizer_class(cls):
|
|||
cond_fn = lambda i, _: tf.less(i, self._num_microbatches)
|
||||
body_fn = lambda i, state: [tf.add(i, 1), process_microbatch(i, state)] # pylint: disable=line-too-long
|
||||
idx = tf.constant(0)
|
||||
_, sample_state = tf.while_loop(cond_fn, body_fn, [idx, sample_state])
|
||||
_, sample_state = tf.while_loop(
|
||||
cond=cond_fn, body=body_fn, loop_vars=[idx, sample_state])
|
||||
|
||||
grad_sums, self._global_state = (
|
||||
self._dp_sum_query.get_noised_result(
|
||||
|
@ -178,7 +171,7 @@ def make_optimizer_class(cls):
|
|||
def normalize(v):
|
||||
return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32))
|
||||
|
||||
final_grads = nest.map_structure(normalize, grad_sums)
|
||||
final_grads = tf.nest.map_structure(normalize, grad_sums)
|
||||
|
||||
return list(zip(final_grads, var_list))
|
||||
|
||||
|
@ -220,14 +213,9 @@ def make_gaussian_optimizer_class(cls):
|
|||
|
||||
return DPGaussianOptimizerClass
|
||||
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
AdagradOptimizer = tf.train.AdagradOptimizer
|
||||
AdamOptimizer = tf.train.AdamOptimizer
|
||||
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||
else:
|
||||
AdagradOptimizer = tf.optimizers.Adagrad
|
||||
AdamOptimizer = tf.optimizers.Adam
|
||||
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
||||
AdagradOptimizer = tf.compat.v1.train.AdagradOptimizer
|
||||
AdamOptimizer = tf.compat.v1.train.AdamOptimizer
|
||||
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
|
||||
|
||||
DPAdagradOptimizer = make_optimizer_class(AdagradOptimizer)
|
||||
DPAdamOptimizer = make_optimizer_class(AdamOptimizer)
|
||||
|
|
|
@ -29,11 +29,12 @@ from tensorflow_privacy.privacy.optimizers import dp_optimizer
|
|||
class DPOptimizerEagerTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
tf.enable_eager_execution()
|
||||
tf.compat.v1.enable_eager_execution()
|
||||
super(DPOptimizerEagerTest, self).setUp()
|
||||
|
||||
def _loss_fn(self, val0, val1):
|
||||
return 0.5 * tf.reduce_sum(tf.squared_difference(val0, val1), axis=1)
|
||||
return 0.5 * tf.reduce_sum(
|
||||
input_tensor=tf.math.squared_difference(val0, val1), axis=1)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1,
|
||||
|
@ -62,7 +63,7 @@ class DPOptimizerEagerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
num_microbatches=num_microbatches,
|
||||
learning_rate=2.0)
|
||||
|
||||
self.evaluate(tf.global_variables_initializer())
|
||||
self.evaluate(tf.compat.v1.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||
|
||||
|
@ -87,7 +88,7 @@ class DPOptimizerEagerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0)
|
||||
|
||||
self.evaluate(tf.global_variables_initializer())
|
||||
self.evaluate(tf.compat.v1.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([0.0, 0.0], self.evaluate(var0))
|
||||
|
||||
|
@ -111,7 +112,7 @@ class DPOptimizerEagerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0)
|
||||
|
||||
self.evaluate(tf.global_variables_initializer())
|
||||
self.evaluate(tf.compat.v1.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([0.0], self.evaluate(var0))
|
||||
|
||||
|
|
|
@ -31,7 +31,8 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
def _loss(self, val0, val1):
|
||||
"""Loss function that is minimized at the mean of the input points."""
|
||||
return 0.5 * tf.reduce_sum(tf.squared_difference(val0, val1), axis=1)
|
||||
return 0.5 * tf.reduce_sum(
|
||||
input_tensor=tf.math.squared_difference(val0, val1), axis=1)
|
||||
|
||||
# Parameters for testing: optimizer, num_microbatches, expected answer.
|
||||
@parameterized.named_parameters(
|
||||
|
@ -61,7 +62,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
num_microbatches=num_microbatches,
|
||||
learning_rate=2.0)
|
||||
|
||||
self.evaluate(tf.global_variables_initializer())
|
||||
self.evaluate(tf.compat.v1.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||
|
||||
|
@ -85,7 +86,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0)
|
||||
|
||||
self.evaluate(tf.global_variables_initializer())
|
||||
self.evaluate(tf.compat.v1.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([0.0, 0.0], self.evaluate(var0))
|
||||
|
||||
|
@ -108,7 +109,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0)
|
||||
|
||||
self.evaluate(tf.global_variables_initializer())
|
||||
self.evaluate(tf.compat.v1.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([0.0], self.evaluate(var0))
|
||||
|
||||
|
@ -121,16 +122,16 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
|
||||
self.assertNear(np.std(grads), 2.0 * 4.0, 0.5)
|
||||
|
||||
@mock.patch.object(tf, 'logging')
|
||||
@mock.patch('absl.logging.warning')
|
||||
def testComputeGradientsOverrideWarning(self, mock_logging):
|
||||
|
||||
class SimpleOptimizer(tf.train.Optimizer):
|
||||
class SimpleOptimizer(tf.compat.v1.train.Optimizer):
|
||||
|
||||
def compute_gradients(self):
|
||||
return 0
|
||||
|
||||
dp_optimizer.make_optimizer_class(SimpleOptimizer)
|
||||
mock_logging.warning.assert_called_once_with(
|
||||
mock_logging.assert_called_once_with(
|
||||
'WARNING: Calling make_optimizer_class() on class %s that overrides '
|
||||
'method compute_gradients(). Check to ensure that '
|
||||
'make_optimizer_class() does not interfere with overridden version.',
|
||||
|
@ -143,15 +144,15 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
preds = tf.keras.layers.Dense(
|
||||
1, activation='linear', name='dense').apply(features['x'])
|
||||
|
||||
vector_loss = tf.squared_difference(labels, preds)
|
||||
scalar_loss = tf.reduce_mean(vector_loss)
|
||||
vector_loss = tf.math.squared_difference(labels, preds)
|
||||
scalar_loss = tf.reduce_mean(input_tensor=vector_loss)
|
||||
dp_sum_query = gaussian_query.GaussianSumQuery(1.0, 0.0)
|
||||
dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 1 / 1e6)
|
||||
optimizer = dp_optimizer.DPGradientDescentOptimizer(
|
||||
dp_sum_query,
|
||||
num_microbatches=1,
|
||||
learning_rate=1.0)
|
||||
global_step = tf.train.get_global_step()
|
||||
global_step = tf.compat.v1.train.get_global_step()
|
||||
train_op = optimizer.minimize(loss=vector_loss, global_step=global_step)
|
||||
return tf.estimator.EstimatorSpec(
|
||||
mode=mode, loss=scalar_loss, train_op=train_op)
|
||||
|
@ -165,7 +166,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
true_weights) + true_bias + np.random.normal(
|
||||
scale=0.1, size=(200, 1)).astype(np.float32)
|
||||
|
||||
train_input_fn = tf.estimator.inputs.numpy_input_fn(
|
||||
train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
|
||||
x={'x': train_data},
|
||||
y=train_labels,
|
||||
batch_size=20,
|
||||
|
@ -198,7 +199,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
learning_rate=2.0,
|
||||
unroll_microbatches=True)
|
||||
|
||||
self.evaluate(tf.global_variables_initializer())
|
||||
self.evaluate(tf.compat.v1.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||
|
||||
|
@ -223,7 +224,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
num_microbatches=1,
|
||||
learning_rate=2.0)
|
||||
|
||||
self.evaluate(tf.global_variables_initializer())
|
||||
self.evaluate(tf.compat.v1.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([0.0], self.evaluate(var0))
|
||||
|
||||
|
|
|
@ -17,33 +17,22 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
from absl import logging
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
nest = tf.contrib.framework.nest
|
||||
AdagradOptimizer = tf.train.AdagradOptimizer
|
||||
AdamOptimizer = tf.train.AdamOptimizer
|
||||
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||
parent_code = tf.train.Optimizer.compute_gradients.__code__
|
||||
GATE_OP = tf.train.Optimizer.GATE_OP # pylint: disable=invalid-name
|
||||
else:
|
||||
nest = tf.nest
|
||||
AdagradOptimizer = tf.optimizers.Adagrad
|
||||
AdamOptimizer = tf.optimizers.Adam
|
||||
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
||||
parent_code = tf.optimizers.Optimizer._compute_gradients.__code__ # pylint: disable=protected-access
|
||||
GATE_OP = None # pylint: disable=invalid-name
|
||||
AdagradOptimizer = tf.compat.v1.train.AdagradOptimizer
|
||||
AdamOptimizer = tf.compat.v1.train.AdamOptimizer
|
||||
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
|
||||
parent_code = tf.compat.v1.train.Optimizer.compute_gradients.__code__
|
||||
GATE_OP = tf.compat.v1.train.Optimizer.GATE_OP # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def make_vectorized_optimizer_class(cls):
|
||||
"""Constructs a vectorized DP optimizer class from an existing one."""
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
child_code = cls.compute_gradients.__code__
|
||||
else:
|
||||
child_code = cls._compute_gradients.__code__ # pylint: disable=protected-access
|
||||
child_code = cls.compute_gradients.__code__
|
||||
if child_code is not parent_code:
|
||||
tf.logging.warning(
|
||||
logging.warning(
|
||||
'WARNING: Calling make_optimizer_class() on class %s that overrides '
|
||||
'method compute_gradients(). Check to ensure that '
|
||||
'make_optimizer_class() does not interfere with overridden version.',
|
||||
|
@ -89,7 +78,7 @@ def make_vectorized_optimizer_class(cls):
|
|||
if gradient_tape:
|
||||
raise ValueError('When in graph mode, a tape should not be passed.')
|
||||
|
||||
batch_size = tf.shape(loss)[0]
|
||||
batch_size = tf.shape(input=loss)[0]
|
||||
if self._num_microbatches is None:
|
||||
self._num_microbatches = batch_size
|
||||
|
||||
|
@ -101,12 +90,12 @@ def make_vectorized_optimizer_class(cls):
|
|||
|
||||
if var_list is None:
|
||||
var_list = (
|
||||
tf.trainable_variables() + tf.get_collection(
|
||||
tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
|
||||
tf.compat.v1.trainable_variables() + tf.compat.v1.get_collection(
|
||||
tf.compat.v1.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
|
||||
|
||||
def process_microbatch(microbatch_loss):
|
||||
"""Compute clipped grads for one microbatch."""
|
||||
microbatch_loss = tf.reduce_mean(microbatch_loss)
|
||||
microbatch_loss = tf.reduce_mean(input_tensor=microbatch_loss)
|
||||
grads, _ = zip(*super(DPOptimizerClass, self).compute_gradients(
|
||||
microbatch_loss,
|
||||
var_list,
|
||||
|
@ -122,26 +111,28 @@ def make_vectorized_optimizer_class(cls):
|
|||
# Here, we use TF primitives rather than the built-in
|
||||
# tf.clip_by_global_norm() so that operations can be vectorized
|
||||
# across microbatches.
|
||||
grads_flat = nest.flatten(grads_list)
|
||||
squared_l2_norms = [tf.reduce_sum(tf.square(g)) for g in grads_flat]
|
||||
grads_flat = tf.nest.flatten(grads_list)
|
||||
squared_l2_norms = [
|
||||
tf.reduce_sum(input_tensor=tf.square(g)) for g in grads_flat
|
||||
]
|
||||
global_norm = tf.sqrt(tf.add_n(squared_l2_norms))
|
||||
div = tf.maximum(global_norm / self._l2_norm_clip, 1.)
|
||||
clipped_flat = [g / div for g in grads_flat]
|
||||
clipped_grads = nest.pack_sequence_as(grads_list, clipped_flat)
|
||||
clipped_grads = tf.nest.pack_sequence_as(grads_list, clipped_flat)
|
||||
return clipped_grads
|
||||
|
||||
clipped_grads = tf.vectorized_map(process_microbatch, microbatch_losses)
|
||||
|
||||
def reduce_noise_normalize_batch(stacked_grads):
|
||||
summed_grads = tf.reduce_sum(stacked_grads, axis=0)
|
||||
summed_grads = tf.reduce_sum(input_tensor=stacked_grads, axis=0)
|
||||
noise_stddev = self._l2_norm_clip * self._noise_multiplier
|
||||
noise = tf.random.normal(tf.shape(summed_grads),
|
||||
stddev=noise_stddev)
|
||||
noise = tf.random.normal(
|
||||
tf.shape(input=summed_grads), stddev=noise_stddev)
|
||||
noised_grads = summed_grads + noise
|
||||
return noised_grads / tf.cast(self._num_microbatches, tf.float32)
|
||||
|
||||
final_grads = nest.map_structure(reduce_noise_normalize_batch,
|
||||
clipped_grads)
|
||||
final_grads = tf.nest.map_structure(reduce_noise_normalize_batch,
|
||||
clipped_grads)
|
||||
|
||||
return list(zip(final_grads, var_list))
|
||||
|
||||
|
|
|
@ -32,7 +32,8 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
|
||||
def _loss(self, val0, val1):
|
||||
"""Loss function that is minimized at the mean of the input points."""
|
||||
return 0.5 * tf.reduce_sum(tf.squared_difference(val0, val1), axis=1)
|
||||
return 0.5 * tf.reduce_sum(
|
||||
input_tensor=tf.math.squared_difference(val0, val1), axis=1)
|
||||
|
||||
# Parameters for testing: optimizer, num_microbatches, expected answer.
|
||||
@parameterized.named_parameters(
|
||||
|
@ -56,7 +57,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
num_microbatches=num_microbatches,
|
||||
learning_rate=2.0)
|
||||
|
||||
self.evaluate(tf.global_variables_initializer())
|
||||
self.evaluate(tf.compat.v1.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||
|
||||
|
@ -80,7 +81,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
num_microbatches=1,
|
||||
learning_rate=2.0)
|
||||
|
||||
self.evaluate(tf.global_variables_initializer())
|
||||
self.evaluate(tf.compat.v1.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([0.0, 0.0], self.evaluate(var0))
|
||||
|
||||
|
@ -103,7 +104,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
num_microbatches=1,
|
||||
learning_rate=2.0)
|
||||
|
||||
self.evaluate(tf.global_variables_initializer())
|
||||
self.evaluate(tf.compat.v1.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([0.0], self.evaluate(var0))
|
||||
|
||||
|
@ -116,16 +117,16 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
|
||||
self.assertNear(np.std(grads), 4.0 * 8.0, 0.5)
|
||||
|
||||
@mock.patch.object(tf, 'logging')
|
||||
@mock.patch('absl.logging.warning')
|
||||
def testComputeGradientsOverrideWarning(self, mock_logging):
|
||||
|
||||
class SimpleOptimizer(tf.train.Optimizer):
|
||||
class SimpleOptimizer(tf.compat.v1.train.Optimizer):
|
||||
|
||||
def compute_gradients(self):
|
||||
return 0
|
||||
|
||||
dp_optimizer_vectorized.make_vectorized_optimizer_class(SimpleOptimizer)
|
||||
mock_logging.warning.assert_called_once_with(
|
||||
mock_logging.assert_called_once_with(
|
||||
'WARNING: Calling make_optimizer_class() on class %s that overrides '
|
||||
'method compute_gradients(). Check to ensure that '
|
||||
'make_optimizer_class() does not interfere with overridden version.',
|
||||
|
@ -138,14 +139,14 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
preds = tf.keras.layers.Dense(
|
||||
1, activation='linear', name='dense').apply(features['x'])
|
||||
|
||||
vector_loss = tf.squared_difference(labels, preds)
|
||||
scalar_loss = tf.reduce_mean(vector_loss)
|
||||
vector_loss = tf.math.squared_difference(labels, preds)
|
||||
scalar_loss = tf.reduce_mean(input_tensor=vector_loss)
|
||||
optimizer = VectorizedDPSGD(
|
||||
l2_norm_clip=1.0,
|
||||
noise_multiplier=0.,
|
||||
num_microbatches=1,
|
||||
learning_rate=1.0)
|
||||
global_step = tf.train.get_global_step()
|
||||
global_step = tf.compat.v1.train.get_global_step()
|
||||
train_op = optimizer.minimize(loss=vector_loss, global_step=global_step)
|
||||
return tf.estimator.EstimatorSpec(
|
||||
mode=mode, loss=scalar_loss, train_op=train_op)
|
||||
|
@ -159,7 +160,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
true_weights) + true_bias + np.random.normal(
|
||||
scale=0.1, size=(200, 1)).astype(np.float32)
|
||||
|
||||
train_input_fn = tf.estimator.inputs.numpy_input_fn(
|
||||
train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
|
||||
x={'x': train_data},
|
||||
y=train_labels,
|
||||
batch_size=20,
|
||||
|
@ -186,7 +187,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
|||
num_microbatches=1,
|
||||
learning_rate=2.0)
|
||||
|
||||
self.evaluate(tf.global_variables_initializer())
|
||||
self.evaluate(tf.compat.v1.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([0.0], self.evaluate(var0))
|
||||
|
||||
|
|
|
@ -21,8 +21,6 @@ from __future__ import print_function
|
|||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
@ -31,10 +29,7 @@ from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp_from_
|
|||
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
|
||||
from tensorflow_privacy.privacy.optimizers import dp_optimizer
|
||||
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||
else:
|
||||
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
||||
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
@ -97,7 +92,7 @@ def cnn_model_fn(features, labels, mode):
|
|||
vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
||||
labels=labels, logits=logits)
|
||||
# Define mean of loss across minibatch (for reporting through tf.Estimator).
|
||||
scalar_loss = tf.reduce_mean(vector_loss)
|
||||
scalar_loss = tf.reduce_mean(input_tensor=vector_loss)
|
||||
|
||||
# Configure the training op (for TRAIN mode).
|
||||
if mode == tf.estimator.ModeKeys.TRAIN:
|
||||
|
@ -125,7 +120,7 @@ def cnn_model_fn(features, labels, mode):
|
|||
optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
|
||||
training_hooks = []
|
||||
opt_loss = scalar_loss
|
||||
global_step = tf.train.get_global_step()
|
||||
global_step = tf.compat.v1.train.get_global_step()
|
||||
train_op = optimizer.minimize(loss=opt_loss, global_step=global_step)
|
||||
# In the following, we pass the mean of the loss (scalar_loss) rather than
|
||||
# the vector_loss because tf.estimator requires a scalar loss. This is only
|
||||
|
@ -140,7 +135,7 @@ def cnn_model_fn(features, labels, mode):
|
|||
elif mode == tf.estimator.ModeKeys.EVAL:
|
||||
eval_metric_ops = {
|
||||
'accuracy':
|
||||
tf.metrics.accuracy(
|
||||
tf.compat.v1.metrics.accuracy(
|
||||
labels=labels,
|
||||
predictions=tf.argmax(input=logits, axis=1))
|
||||
}
|
||||
|
@ -173,7 +168,7 @@ def load_mnist():
|
|||
|
||||
|
||||
def main(unused_argv):
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
|
||||
if FLAGS.dpsgd and FLAGS.batch_size % FLAGS.microbatches != 0:
|
||||
raise ValueError('Number of microbatches should divide evenly batch_size')
|
||||
|
||||
|
@ -185,13 +180,13 @@ def main(unused_argv):
|
|||
model_dir=FLAGS.model_dir)
|
||||
|
||||
# Create tf.Estimator input functions for the training and test data.
|
||||
train_input_fn = tf.estimator.inputs.numpy_input_fn(
|
||||
train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
|
||||
x={'x': train_data},
|
||||
y=train_labels,
|
||||
batch_size=FLAGS.batch_size,
|
||||
num_epochs=FLAGS.epochs,
|
||||
shuffle=True)
|
||||
eval_input_fn = tf.estimator.inputs.numpy_input_fn(
|
||||
eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
|
||||
x={'x': test_data},
|
||||
y=test_labels,
|
||||
num_epochs=1,
|
||||
|
|
|
@ -19,8 +19,6 @@ from __future__ import print_function
|
|||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
@ -28,11 +26,8 @@ from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
|
|||
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
|
||||
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
|
||||
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||
tf.enable_eager_execution()
|
||||
else:
|
||||
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
||||
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
|
||||
tf.compat.v1.enable_eager_execution()
|
||||
|
||||
flags.DEFINE_boolean('dpsgd', True, 'If True, train with DP-SGD. If False, '
|
||||
'train with vanilla SGD.')
|
||||
|
@ -124,7 +119,7 @@ def main(_):
|
|||
labels=labels, logits=logits) # pylint: disable=undefined-loop-variable,cell-var-from-loop
|
||||
# If training without privacy, the loss is a scalar not a vector.
|
||||
if not FLAGS.dpsgd:
|
||||
loss = tf.reduce_mean(loss)
|
||||
loss = tf.reduce_mean(input_tensor=loss)
|
||||
return loss
|
||||
|
||||
if FLAGS.dpsgd:
|
||||
|
@ -138,7 +133,7 @@ def main(_):
|
|||
# Evaluate the model and print results
|
||||
for (_, (images, labels)) in enumerate(eval_dataset.take(-1)):
|
||||
logits = mnist_model(images, training=False)
|
||||
correct_preds = tf.equal(tf.argmax(logits, axis=1), labels)
|
||||
correct_preds = tf.equal(tf.argmax(input=logits, axis=1), labels)
|
||||
test_accuracy = np.mean(correct_preds.numpy())
|
||||
print('Test accuracy after epoch %d is: %.3f' % (epoch, test_accuracy))
|
||||
|
||||
|
|
|
@ -19,8 +19,7 @@ from __future__ import print_function
|
|||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
from absl import logging
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
@ -29,10 +28,7 @@ from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
|
|||
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
|
||||
from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
|
||||
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||
else:
|
||||
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
||||
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
'dpsgd', True, 'If True, train with DP-SGD. If False, '
|
||||
|
@ -92,7 +88,7 @@ def load_mnist():
|
|||
|
||||
|
||||
def main(unused_argv):
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
logging.set_verbosity(logging.INFO)
|
||||
if FLAGS.dpsgd and FLAGS.batch_size % FLAGS.microbatches != 0:
|
||||
raise ValueError('Number of microbatches should divide evenly batch_size')
|
||||
|
||||
|
@ -125,7 +121,7 @@ def main(unused_argv):
|
|||
learning_rate=FLAGS.learning_rate)
|
||||
# Compute vector of per-example loss rather than its mean over a minibatch.
|
||||
loss = tf.keras.losses.CategoricalCrossentropy(
|
||||
from_logits=True, reduction=tf.losses.Reduction.NONE)
|
||||
from_logits=True, reduction=tf.compat.v1.losses.Reduction.NONE)
|
||||
else:
|
||||
optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
|
||||
loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
|
||||
|
|
|
@ -21,8 +21,6 @@ from __future__ import print_function
|
|||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
@ -45,17 +43,11 @@ flags.DEFINE_integer(
|
|||
'(must evenly divide batch_size)')
|
||||
flags.DEFINE_string('model_dir', None, 'Model directory')
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
NUM_TRAIN_EXAMPLES = 60000
|
||||
|
||||
|
||||
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||
GradientDescentOptimizer = tf.train.GradientDescentOptimizer
|
||||
else:
|
||||
GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name
|
||||
GradientDescentOptimizer = tf.compat.v1.train.GradientDescentOptimizer
|
||||
|
||||
|
||||
def compute_epsilon(steps):
|
||||
|
@ -95,7 +87,7 @@ def cnn_model_fn(features, labels, mode):
|
|||
vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
||||
labels=labels, logits=logits)
|
||||
# Define mean of loss across minibatch (for reporting through tf.Estimator).
|
||||
scalar_loss = tf.reduce_mean(vector_loss)
|
||||
scalar_loss = tf.reduce_mean(input_tensor=vector_loss)
|
||||
|
||||
# Configure the training op (for TRAIN mode).
|
||||
if mode == tf.estimator.ModeKeys.TRAIN:
|
||||
|
@ -114,7 +106,7 @@ def cnn_model_fn(features, labels, mode):
|
|||
else:
|
||||
optimizer = GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
|
||||
opt_loss = scalar_loss
|
||||
global_step = tf.train.get_global_step()
|
||||
global_step = tf.compat.v1.train.get_global_step()
|
||||
train_op = optimizer.minimize(loss=opt_loss, global_step=global_step)
|
||||
# In the following, we pass the mean of the loss (scalar_loss) rather than
|
||||
# the vector_loss because tf.estimator requires a scalar loss. This is only
|
||||
|
@ -128,7 +120,7 @@ def cnn_model_fn(features, labels, mode):
|
|||
elif mode == tf.estimator.ModeKeys.EVAL:
|
||||
eval_metric_ops = {
|
||||
'accuracy':
|
||||
tf.metrics.accuracy(
|
||||
tf.compat.v1.metrics.accuracy(
|
||||
labels=labels,
|
||||
predictions=tf.argmax(input=logits, axis=1))
|
||||
}
|
||||
|
@ -161,7 +153,7 @@ def load_mnist():
|
|||
|
||||
|
||||
def main(unused_argv):
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
|
||||
if FLAGS.dpsgd and FLAGS.batch_size % FLAGS.microbatches != 0:
|
||||
raise ValueError('Number of microbatches should divide evenly batch_size')
|
||||
|
||||
|
@ -173,13 +165,13 @@ def main(unused_argv):
|
|||
model_dir=FLAGS.model_dir)
|
||||
|
||||
# Create tf.Estimator input functions for the training and test data.
|
||||
train_input_fn = tf.estimator.inputs.numpy_input_fn(
|
||||
train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
|
||||
x={'x': train_data},
|
||||
y=train_labels,
|
||||
batch_size=FLAGS.batch_size,
|
||||
num_epochs=FLAGS.epochs,
|
||||
shuffle=True)
|
||||
eval_input_fn = tf.estimator.inputs.numpy_input_fn(
|
||||
eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
|
||||
x={'x': test_data},
|
||||
y=test_labels,
|
||||
num_epochs=1,
|
||||
|
|
Loading…
Reference in a new issue