diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py
index 7e94dfa..12fa53f 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py
@@ -136,8 +136,9 @@ def embedding_layer_computation(layer_instance, inputs):
       `layer_instance(inputs)` returns a valid output.
 
   Returns:
-    A `tuple` `(base_vars, None, sqr_norm_fn)`, `base_vars` is the
-    intermediate Tensor used in the chain-rule / "fast" clipping trick, and
+    A `tuple` `(base_vars, transform, sqr_norm_fn)`, where `base_vars` is the
+    intermediate Tensor used in the chain-rule / "fast" clipping trick,
+    `transform` maps `base_vars` to the layer's output on `inputs`, and
     `sqr_norm_fn` is a function that takes one input, a `tf.Tensor` that
     represents the output of the call `tape.gradient(summed_loss, base_vars)`
     where `tape` is a `tf.GradientTape` instance that records the dense
@@ -148,41 +149,72 @@ def embedding_layer_computation(layer_instance, inputs):
   """
   if hasattr(layer_instance, "sparse"):  # for backwards compatibility
     if layer_instance.sparse:
-      raise NotImplementedError("Sparse output vectors are not supported.")
-  if len(inputs[0].shape) != 2:
-    raise NotImplementedError("Only 2D embedding inputs are supported.")
-  # The logic below is applied to properly handle repeated embedding indices.
-  # Specifically, sqr_grad_norms will contain the total counts of each embedding
-  # index (see how it is processed in the combine_pre_and_post_sqr_norms()
-  # function in clip_grads.py). An example is as follows:
-  #
-  #   inputs =
-  #     [[0 0 0 1 2 2],
-  #      [0 2 2 2 1 1]]
-  #
-  #   counts =
-  #     [[3 1 2]
-  #      [1 2 3]]
-  #
-  #   input_counts =
-  #     [[3 3 3 1 2 2],
-  #      [1 3 3 3 2 2]]
-  #
-  base_vars = layer_instance(*inputs)
-  def sqr_norm_fn(base_vars_grads):
-    indices = tf.cast(*inputs, tf.int32)
-    if isinstance(indices, tf.SparseTensor):
-      indices = tf.sparse.to_dense(indices)
-    counts = tf.math.bincount(indices, axis=-1)
-    input_counts = tf.expand_dims(
-        tf.cast(tf.gather(counts, indices, batch_dims=1), base_vars.dtype),
-        axis=-1,
-    )
-    scaled_grads = input_counts * tf.square(base_vars_grads)
-    reduction_axes = tf.range(1, tf.rank(scaled_grads))
-    return tf.reduce_sum(scaled_grads, axis=reduction_axes)
+      raise NotImplementedError("Sparse output tensors are not supported.")
+  if isinstance(inputs, tf.SparseTensor):
+    raise NotImplementedError("Sparse input tensors are not supported.")
 
-  return base_vars, None, sqr_norm_fn
+  # Disable experimental features.
+  if hasattr(layer_instance, "_use_one_hot_matmul"):
+    if layer_instance._use_one_hot_matmul:  # pylint: disable=protected-access
+      raise NotImplementedError(
+          "The experimental embedding feature "
+          "'_use_one_hot_matmul' is not supported."
+      )
+  input_ids = tf.cast(*inputs, tf.int32)
+  base_vars = layer_instance.trainable_variables[0]
+
+  def lookup_inputs(embeddings):
+    return tf.nn.embedding_lookup(embeddings, input_ids)
+
+  def sqr_norm_fn(base_vars_grads):
+    # Get a column tensor of row (i.e., per-example) indices, one per input id.
+    nrows = tf.shape(input_ids)[0]
+    if isinstance(input_ids, tf.RaggedTensor):
+      row_indices = tf.expand_dims(
+          input_ids.merge_dims(1, -1).value_rowids(), axis=-1
+      )
+    elif isinstance(input_ids, tf.Tensor):
+      ncols = tf.reduce_prod(tf.shape(input_ids)[1:])
+      repeats = tf.repeat(ncols, nrows)
+      row_indices = tf.reshape(tf.repeat(tf.range(nrows), repeats), [-1, 1])
+    else:
+      raise NotImplementedError(
+          "Cannot parse input_ids of type %s" % input_ids.__class__.__name__
+      )
+    # Sum-reduce the `tf.IndexedSlices` returned by the `tape.gradient()`
+    # call, grouping by unique (batch index, embedding index) pairs so that
+    # repeated embedding indices are accumulated; adapted from the logic in:
+    #   tf.keras.optimizers.legacy.optimizer_v2._deduplicate_indexed_slices
+    if not isinstance(base_vars_grads, tf.IndexedSlices):
+      raise NotImplementedError(
+          "Cannot parse embedding gradients of type: %s"
+          % base_vars_grads.__class__.__name__
+      )
+    slice_indices = tf.expand_dims(base_vars_grads.indices, axis=-1)
+    paired_indices = tf.concat(
+        [tf.cast(row_indices, tf.int64), tf.cast(slice_indices, tf.int64)],
+        axis=1,
+    )
+    (unique_paired_indices, new_index_positions) = tf.raw_ops.UniqueV2(
+        x=paired_indices, axis=[0]
+    )
+    unique_batch_ids = unique_paired_indices[:, 0]
+    summed_gradients = tf.math.unsorted_segment_sum(
+        base_vars_grads.values,
+        new_index_positions,
+        tf.shape(unique_paired_indices)[0],
+    )
+    # Compute the squared gradient norms at the per-example level.
+    sqr_gradient_sum = tf.reduce_sum(tf.square(summed_gradients), axis=1)
+    summed_data_range = tf.range(tf.shape(sqr_gradient_sum)[0])
+    return tf.sparse.segment_sum(
+        sqr_gradient_sum,
+        summed_data_range,
+        tf.sort(unique_batch_ids),
+        num_segments=nrows,
+    )  # `num_segments=nrows` zero-fills examples with no gradient entries
+
+  return base_vars, lookup_inputs, sqr_norm_fn
 
 
 # ==============================================================================
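For reference, a minimal usage sketch of the new code path. `embedding_layer_computation` and its module path are taken from the diff above; the layer, input batch, and loss (`layer`, `x`, `summed_loss`) are illustrative assumptions, with `summed_loss` standing in for a real summed per-example loss.

```python
import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

# Toy embedding layer and a batch of two examples with repeated indices
# (the same inputs as in the removed comment block above).
layer = tf.keras.layers.Embedding(input_dim=3, output_dim=4)
x = tf.constant([[0, 0, 0, 1, 2, 2],
                 [0, 2, 2, 2, 1, 1]])
_ = layer(x)  # build the layer so `trainable_variables` is populated

base_vars, lookup_inputs, sqr_norm_fn = (
    layer_registry.embedding_layer_computation(layer, (x,))
)

with tf.GradientTape() as tape:
  outputs = lookup_inputs(base_vars)    # equivalent to layer(x)
  summed_loss = tf.reduce_sum(outputs)  # toy scalar loss over the batch
base_vars_grads = tape.gradient(summed_loss, base_vars)  # a tf.IndexedSlices

# One squared gradient norm per example, with repeated indices deduplicated.
per_example_sqr_norms = sqr_norm_fn(base_vars_grads)  # shape [2]
print(per_example_sqr_norms)
```

The dedup-then-square step mirrors computing each example's embedding gradient separately and taking its squared L2 norm, which is the quantity the fast-clipping trick needs, without ever materializing per-example gradients.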