Update privacy/keras_models.
This commit is contained in:
parent
7a00a1cfef
commit
7dad2d18e8
2 changed files with 59 additions and 30 deletions
59
tensorflow_privacy/privacy/keras_models/dp_keras_model.py
Normal file
59
tensorflow_privacy/privacy/keras_models/dp_keras_model.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
import tensorflow as tf
|
||||
|
||||
|
||||
def make_dp_model_class(cls):
|
||||
class DPModelClass(cls):
|
||||
def __init__(self, l2_norm_clip, noise_multiplier, use_xla=True, *args, **kwargs):
|
||||
super(DPModelClass, self).__init__(*args, **kwargs)
|
||||
self._l2_norm_clip = l2_norm_clip
|
||||
self._noise_multiplier = noise_multiplier
|
||||
|
||||
if use_xla:
|
||||
self.train_step = tf.function(
|
||||
self.train_step, experimental_compile=True)
|
||||
|
||||
def process_per_example_grads(self, grads):
|
||||
grads_flat = tf.nest.flatten(grads)
|
||||
squared_l2_norms = [tf.reduce_sum(
|
||||
input_tensor=tf.square(g)) for g in grads_flat]
|
||||
global_norm = tf.sqrt(tf.add_n(squared_l2_norms))
|
||||
div = tf.maximum(global_norm / self._l2_norm_clip, 1.)
|
||||
clipped_flat = [g / div for g in grads_flat]
|
||||
return tf.nest.pack_sequence_as(grads, clipped_flat)
|
||||
|
||||
def reduce_per_example_grads(self, stacked_grads):
|
||||
summed_grads = tf.reduce_sum(input_tensor=stacked_grads, axis=0)
|
||||
noise_stddev = self._l2_norm_clip * self._noise_multiplier
|
||||
noise = tf.random.normal(
|
||||
tf.shape(input=summed_grads), stddev=noise_stddev)
|
||||
noised_grads = summed_grads + noise
|
||||
return noised_grads / tf.cast(stacked_grads.shape[0], noised_grads.dtype)
|
||||
|
||||
def compute_per_example_grads(self, data):
|
||||
x, y = data
|
||||
with tf.GradientTape() as tape:
|
||||
# We need to add the extra dimension to x and y because model
|
||||
# expects batched input.
|
||||
y_pred = self(x[None], training=True)
|
||||
loss = self.compiled_loss(y[None], y_pred,
|
||||
regularization_losses=self.losses)
|
||||
|
||||
grads_list = tape.gradient(loss, self.trainable_variables)
|
||||
clipped_grads = self.process_per_example_grads(grads_list)
|
||||
return tf.squeeze(y_pred, axis=0), loss, clipped_grads
|
||||
|
||||
def train_step(self, data):
|
||||
x, y = data
|
||||
y_pred, per_eg_loss, per_eg_grads = tf.vectorized_map(
|
||||
self.compute_per_example_grads, data)
|
||||
loss = tf.reduce_mean(per_eg_loss, axis=0)
|
||||
grads = tf.nest.map_structure(self.reduce_per_example_grads, per_eg_grads)
|
||||
self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
|
||||
self.compiled_metrics.update_state(y, y_pred)
|
||||
return {m.name: m.result() for m in self.metrics}
|
||||
|
||||
return DPModelClass
|
||||
|
||||
|
||||
DPModel = make_dp_model_class(tf.keras.Model)
|
||||
DPSequential = make_dp_model_class(tf.keras.Sequential)
|
|
@ -1,30 +0,0 @@
|
|||
import tensorflow as tf
|
||||
|
||||
|
||||
def make_dp_model_class(cls):
|
||||
class DPModelClass(cls):
|
||||
def __init__(self, l2_norm_clip, noise_multiplier, use_xla=True, *args, **kwargs):
|
||||
super(DPModelClass, self).__init__(*args, **kwargs)
|
||||
self._l2_norm_clip = l2_norm_clip
|
||||
self._noise_multiplier = noise_multiplier
|
||||
|
||||
if use_xla:
|
||||
self.train_step = tf.function(
|
||||
self.train_step, experimental_compile=True)
|
||||
|
||||
def process_per_example_grads(self, grads):
|
||||
grads_flat = tf.nest.flatten(grads)
|
||||
squared_l2_norms = [tf.reduce_sum(
|
||||
input_tensor=tf.square(g)) for g in grads_flat]
|
||||
global_norm = tf.sqrt(tf.add_n(squared_l2_norms))
|
||||
div = tf.maximum(global_norm / self._l2_norm_clip, 1.)
|
||||
clipped_flat = [g / div for g in grads_flat]
|
||||
return tf.nest.pack_sequence_as(grads, clipped_flat)
|
||||
|
||||
def reduce_per_example_grads(self, stacked_grads):
|
||||
summed_grads = tf.reduce_sum(input_tensor=stacked_grads, axis=0)
|
||||
noise_stddev = self._l2_norm_clip * self._noise_multiplier
|
||||
noise = tf.random.normal(
|
||||
tf.shape(input=summed_grads), stddev=noise_stddev)
|
||||
noised_grads = summed_grads + noise
|
||||
return noised_grads / tf.cast(stacked_grads.shape[0], tf.float32)
|
Loading…
Reference in a new issue