diff --git a/privacy/bolton/__init__.py b/privacy/bolton/__init__.py new file mode 100644 index 0000000..46bd079 --- /dev/null +++ b/privacy/bolton/__init__.py @@ -0,0 +1,14 @@ +import sys +from distutils.version import LooseVersion +import tensorflow as tf + +if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): + raise ImportError("Please upgrade your version of tensorflow from: {0} " + "to at least 2.0.0 to use privacy/bolton".format( + LooseVersion(tf.__version__))) +if hasattr(sys, 'skip_tf_privacy_import'): # Useful for standalone scripts. + pass +else: + from privacy.bolton.model import Bolton + from privacy.bolton.loss import Huber + from privacy.bolton.loss import BinaryCrossentropy \ No newline at end of file diff --git a/privacy/bolton/loss.py b/privacy/bolton/loss.py new file mode 100644 index 0000000..dd5d580 --- /dev/null +++ b/privacy/bolton/loss.py @@ -0,0 +1,280 @@ +# Copyright 2018, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Loss functions for bolton method""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import tensorflow as tf +from tensorflow.python.keras import losses +from tensorflow.python.keras.utils import losses_utils +from tensorflow.python.framework import ops as _ops + + +class StrongConvexLoss(losses.Loss): + """ + Strong Convex Loss base class for any loss function that will be used with + Bolton model. Subclasses must be strongly convex and implement the + associated constants. They must also conform to the requirements of tf losses + (see super class) + """ + def __init__(self, + reg_lambda: float, + c: float, + radius_constant: float = 1, + reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, + name: str = None, + dtype=tf.float32, + **kwargs): + """ + Args: + reg_lambda: Weight regularization constant + c: Additional constant for strongly convex convergence. Acts + as a global weight. + radius_constant: constant defining the length of the radius + reduction: reduction type to use. See super class + name: Name of the loss instance + dtype: tf datatype to use for tensor conversions. + """ + super(StrongConvexLoss, self).__init__(reduction=reduction, + name=name, + **kwargs) + self._sample_weight = tf.Variable(initial_value=c, + trainable=False, + dtype=tf.float32) + self._reg_lambda = reg_lambda + self.radius_constant = tf.Variable(initial_value=radius_constant, + trainable=False, + dtype=tf.float32) + self.dtype = dtype + + def radius(self): + """Radius of R-Ball (value to normalize weights to after each batch) + + Returns: radius + + """ + raise NotImplementedError("Radius not implemented for StrongConvex Loss" + "function: %s" % str(self.__class__.__name__)) + + def gamma(self): + """ Gamma strongly convex + + Returns: gamma + + """ + raise NotImplementedError("Gamma not implemented for StrongConvex Loss" + "function: %s" % str(self.__class__.__name__)) + + def beta(self, class_weight): + """Beta smoothess + + Args: + class_weight: the class weights used. + + Returns: Beta + + """ + raise NotImplementedError("Beta not implemented for StrongConvex Loss" + "function: %s" % str(self.__class__.__name__)) + + def lipchitz_constant(self, class_weight): + """ L lipchitz continuous + + Args: + class_weight: class weights used + + Returns: L + + """ + raise NotImplementedError("lipchitz constant not implemented for " + "StrongConvex Loss" + "function: %s" % str(self.__class__.__name__)) + + def reg_lambda(self, convert_to_tensor: bool = False): + """ returns the lambda weight regularization constant, as a tensor if + desired + + Args: + convert_to_tensor: True to convert to tensor, False to leave as + python numeric. + + Returns: reg_lambda + + """ + if convert_to_tensor: + return _ops.convert_to_tensor_v2(self._reg_lambda, dtype=self.dtype) + return self._reg_lambda + + def max_class_weight(self, class_weight): + class_weight = _ops.convert_to_tensor_v2(class_weight, dtype=self.dtype) + return tf.math.reduce_max(class_weight) + + +class Huber(StrongConvexLoss, losses.Huber): + """Strong Convex version of huber loss using l2 weight regularization. + """ + def __init__(self, + reg_lambda: float, + c: float, + radius_constant: float, + delta: float, + reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, + name: str = 'huber', + dtype=tf.float32): + """Constructor. Passes arguments to StrongConvexLoss and Huber Loss. + + Args: + reg_lambda: Weight regularization constant + c: Additional constant for strongly convex convergence. Acts + as a global weight. + radius_constant: constant defining the length of the radius + delta: delta value in huber loss. When to switch from quadratic to + absolute deviation. + reduction: reduction type to use. See super class + name: Name of the loss instance + dtype: tf datatype to use for tensor conversions. + + Returns: + Loss values per sample. + """ + # self.delta = tf.Variable(initial_value=delta, trainable=False) + super(Huber, self).__init__( + reg_lambda, + c, + radius_constant, + delta=delta, + name=name, + reduction=reduction, + dtype=dtype + ) + + def call(self, y_true, y_pred): + """Compute loss + + Args: + y_true: Ground truth values. + y_pred: The predicted values. + + Returns: + Loss values per sample. + """ + return super(Huber, self).call(y_true, y_pred, **self._fn_kwargs) * \ + self._sample_weight + + def radius(self): + """See super class. + """ + return self.radius_constant / self.reg_lambda(True) + + def gamma(self): + """See super class. + """ + return self.reg_lambda(True) + + def beta(self, class_weight): + """See super class. + """ + max_class_weight = self.max_class_weight(class_weight) + return self._sample_weight * max_class_weight / \ + (self.delta * tf.Variable(initial_value=2, trainable=False)) + \ + self.reg_lambda(True) + + def lipchitz_constant(self, class_weight): + """See super class. + """ + # if class_weight is provided, + # it should be a vector of the same size of number of classes + max_class_weight = self.max_class_weight(class_weight) + lc = self._sample_weight * max_class_weight + \ + self.reg_lambda(True) * self.radius() + return lc + + +class BinaryCrossentropy(StrongConvexLoss, losses.BinaryCrossentropy): + """ + Strong Convex version of BinaryCrossentropy loss using l2 weight + regularization. + """ + def __init__(self, + reg_lambda: float, + c: float, + radius_constant: float, + from_logits: bool = True, + label_smoothing: float = 0, + reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, + name: str = 'binarycrossentropy', + dtype=tf.float32): + """ + Args: + reg_lambda: Weight regularization constant + c: Additional constant for strongly convex convergence. Acts + as a global weight. + radius_constant: constant defining the length of the radius + reduction: reduction type to use. See super class + label_smoothing: amount of smoothing to perform on labels + relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x) + name: Name of the loss instance + dtype: tf datatype to use for tensor conversions. + """ + super(BinaryCrossentropy, self).__init__(reg_lambda, + c, + radius_constant, + reduction=reduction, + name=name, + from_logits=from_logits, + label_smoothing=label_smoothing, + dtype=dtype + ) + self.radius_constant = radius_constant + + def call(self, y_true, y_pred): + """Compute loss + + Args: + y_true: Ground truth values. + y_pred: The predicted values. + + Returns: + Loss values per sample. + """ + loss = tf.nn.sigmoid_cross_entropy_with_logits( + labels=y_true, + logits=y_pred + ) + loss = loss * self._sample_weight + return loss + + def radius(self): + """See super class. + """ + return self.radius_constant / self.reg_lambda(True) + + def gamma(self): + """See super class. + """ + return self.reg_lambda(True) + + def beta(self, class_weight): + """See super class. + """ + max_class_weight = self.max_class_weight(class_weight) + return self._sample_weight * max_class_weight + self.reg_lambda(True) + + def lipchitz_constant(self, class_weight): + """See super class. + """ + max_class_weight = self.max_class_weight(class_weight) + return self._sample_weight * max_class_weight + \ + self.reg_lambda(True) * self.radius() diff --git a/privacy/bolton/loss_test.py b/privacy/bolton/loss_test.py new file mode 100644 index 0000000..87669fd --- /dev/null +++ b/privacy/bolton/loss_test.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function \ No newline at end of file diff --git a/privacy/bolton/model.py b/privacy/bolton/model.py new file mode 100644 index 0000000..a600374 --- /dev/null +++ b/privacy/bolton/model.py @@ -0,0 +1,402 @@ +# Copyright 2018, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Bolton model for bolton method of differentially private ML""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import tensorflow as tf +from tensorflow.python.keras.models import Model +from tensorflow.python.keras import optimizers +from tensorflow.python.training.tracking import base as trackable +from tensorflow.python.framework import ops as _ops +from privacy.bolton.loss import StrongConvexLoss +from privacy.bolton.optimizer import Private + + +class Bolton(Model): + """ + Bolton episilon-delta model + Uses 4 key steps to achieve privacy guarantees: + 1. Adds noise to weights after training (output perturbation). + 2. Projects weights to R after each batch + 3. Limits learning rate + 4. Use a strongly convex loss function (see compile) + """ + def __init__(self, + n_classes, + epsilon, + noise_distribution='laplace', + weights_initializer=tf.initializers.GlorotUniform(), + seed=1, + dtype=tf.float32 + ): + """ private constructor. + + Args: + n_classes: number of output classes to predict. + epsilon: level of privacy guarantee + noise_distribution: distribution to pull weight perturbations from + weights_initializer: initializer for weights + seed: random seed to use + dtype: data type to use for tensors + """ + + class MyCustomCallback(tf.keras.callbacks.Callback): + """Custom callback for bolton training requirements. + Implements steps (see Bolton class): + 2. Projects weights to R after each batch + 3. Limits learning rate + """ + def on_train_batch_end(self, batch, logs=None): + loss = self.model.loss + self.model.optimizer.limit_learning_rate( + self.model.run_eagerly, + loss.beta(self.model.class_weight), + loss.gamma() + ) + self.model._project_weights_to_r(loss.radius(), False) + + def on_train_end(self, logs=None): + loss = self.model.loss + self.model._project_weights_to_r(loss.radius(), True) + + super(Bolton, self).__init__(name='bolton', dynamic=False) + self.n_classes = n_classes + self.output_layer = tf.keras.layers.Dense( + self.n_classes, + kernel_regularizer=tf.keras.regularizers.l2(), + kernel_initializer=weights_initializer, + ) + # if we do regularization here, we require the user to re-instantiate + # the model each time they want to + # change lambda, unless we standardize modifying it later at .compile + self.force = False + self.noise_distribution = noise_distribution + self.epsilon = epsilon + self.seed = seed + self.__in_fit = False + self._callback = MyCustomCallback() + self._dtype = dtype + + def call(self, inputs): + """Forward pass of network + + Args: + inputs: inputs to neural network + + Returns: + + """ + return self.output_layer(inputs) + + def compile(self, + optimizer='SGD', + loss=None, + metrics=None, + loss_weights=None, + sample_weight_mode=None, + weighted_metrics=None, + target_tensors=None, + distribute=None, + **kwargs): + """See super class. Default optimizer used in Bolton method is SGD. + + """ + if not isinstance(loss, StrongConvexLoss): + raise ValueError("Loss must be subclassed from StrongConvexLoss") + self.output_layer.kernel_regularizer.l2 = loss.reg_lambda() + if not isinstance(optimizer, Private): + optimizer = optimizers.get(optimizer) + if isinstance(self.optimizer, trackable.Trackable): + self._track_trackable( + self.optimizer, name='optimizer', overwrite=True + ) + optimizer = Private(optimizer) + + super(Bolton, self).compile(optimizer, + loss=loss, + metrics=metrics, + loss_weights=loss_weights, + sample_weight_mode=sample_weight_mode, + weighted_metrics=weighted_metrics, + target_tensors=target_tensors, + distribute=distribute, + **kwargs + ) + + def _post_fit(self, x, n_samples): + """Implements 1-time weight changes needed for Bolton method. + In this case, specifically implements the noise addition + assuming a strongly convex function. + + Args: + x: inputs + n_samples: number of samples in the inputs. In case the number + cannot be readily determined by inspecting x. + + Returns: + + """ + if n_samples is not None: + data_size = n_samples + elif hasattr(x, 'shape'): + data_size = x.shape[0] + elif hasattr(x, "__len__"): + data_size = len(x) + else: + if n_samples is None: + raise ValueError("Unable to detect the number of training " + "samples and n_smaples was None. " + "either pass a dataset with a .shape or " + "__len__ attribute or explicitly pass the " + "number of samples as n_smaples.") + data_size = n_samples + + for layer in self._layers: + layer.kernel = layer.kernel + self._get_noise( + self.noise_distribution, + data_size + ) + + def fit(self, + x=None, + y=None, + batch_size=None, + epochs=1, + verbose=1, + callbacks=None, + validation_split=0.0, + validation_data=None, + shuffle=True, + class_weight=None, + sample_weight=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + validation_freq=1, + max_queue_size=10, + workers=1, + use_multiprocessing=False, + n_samples=None, + **kwargs): + """Reroutes to super fit with additional Bolton delta-epsilon privacy + requirements implemented. Note, inputs must be normalized s.t. ||x|| < 1 + Requirements are as follows: + 1. Adds noise to weights after training (output perturbation). + 2. Projects weights to R after each batch + 3. Limits learning rate + 4. Use a strongly convex loss function (see compile) + See super implementation for more details. + + Args: + n_samples: the number of individual samples in x. + + Returns: + + """ + self.__in_fit = True + cb = [self._callback] + if callbacks is not None: + cb.extend(callbacks) + callbacks = cb + if class_weight is None: + class_weight = self.calculate_class_weights(class_weight) + self.class_weight = class_weight + out = super(Bolton, self).fit(x=x, + y=y, + batch_size=batch_size, + epochs=epochs, + verbose=verbose, + callbacks=callbacks, + validation_split=validation_split, + validation_data=validation_data, + shuffle=shuffle, + class_weight=class_weight, + sample_weight=sample_weight, + initial_epoch=initial_epoch, + steps_per_epoch=steps_per_epoch, + validation_steps=validation_steps, + validation_freq=validation_freq, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing, + **kwargs + ) + self._post_fit(x, n_samples) + self.__in_fit = False + return out + + def fit_generator(self, + generator, + steps_per_epoch=None, + epochs=1, + verbose=1, + callbacks=None, + validation_data=None, + validation_steps=None, + validation_freq=1, + class_weight=None, + max_queue_size=10, + workers=1, + use_multiprocessing=False, + shuffle=True, + initial_epoch=0, + n_samples=None + ): + """ + This method is the same as fit except for when the passed dataset + is a generator. See super method and fit for more details. + Args: + n_samples: number of individual samples in x + + """ + if class_weight is None: + class_weight = self.calculate_class_weights(class_weight) + self.class_weight = class_weight + out = super(Bolton, self).fit_generator( + generator, + steps_per_epoch=steps_per_epoch, + epochs=epochs, + verbose=verbose, + callbacks=callbacks, + validation_data=validation_data, + validation_steps=validation_steps, + validation_freq=validation_freq, + class_weight=class_weight, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing, + shuffle=shuffle, + initial_epoch=initial_epoch + ) + if not self.__in_fit: + self._post_fit(generator, n_samples) + return out + + def calculate_class_weights(self, + class_weights=None, + class_counts=None, + num_classes=None + ): + """ + Calculates class weighting to be used in training. Can be on + Args: + class_weights: str specifying type, array giving weights, or None. + class_counts: If class_weights is not None, then the number of + samples for each class + num_classes: If class_weights is not None, then the number of + classes. + Returns: class_weights as 1D tensor, to be passed to model's fit method. + + """ + # Value checking + class_keys = ['balanced'] + is_string = False + if isinstance(class_weights, str): + is_string = True + if class_weights not in class_keys: + raise ValueError("Detected string class_weights with " + "value: {0}, which is not one of {1}." + "Please select a valid class_weight type" + "or pass an array".format(class_weights, + class_keys)) + if class_counts is None: + raise ValueError("Class counts must be provided if using" + "class_weights=%s" % class_weights) + if num_classes is None: + raise ValueError("Class counts must be provided if using" + "class_weights=%s" % class_weights) + elif class_weights is not None: + if num_classes is None: + raise ValueError("You must pass a value for num_classes if" + "creating an array of class_weights") + # performing class weight calculation + if class_weights is None: + class_weights = 1 + elif is_string and class_weights == 'balanced': + num_samples = sum(class_counts) + class_weights = tf.Variable( + num_samples / (num_classes * class_counts), + dtype=self._dtype + ) + else: + class_weights = _ops.convert_to_tensor_v2(class_weights) + if len(class_weights.shape) != 1: + raise ValueError("Detected class_weights shape: {0} instead of " + "1D array".format(class_weights.shape)) + if class_weights.shape[0] != num_classes: + raise ValueError( + "Detected array length: {0} instead of: {1}".format( + class_weights.shape[0], + num_classes + ) + ) + return class_weights + + def _project_weights_to_r(self, r, force=False): + """helper method to normalize the weights to the R-ball. + + Args: + r: radius of "R-Ball". Scalar to normalize to. + force: True to normalize regardless of previous weight values. + False to check if weights > R-ball and only normalize then. + + Returns: + + """ + for layer in self._layers: + weight_norm = tf.norm(layer.kernel, axis=0) + if force: + layer.kernel = layer.kernel / (weight_norm / r) + elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self._dtype)) > 0: + layer.kernel = layer.kernel / (weight_norm / r) + + def _get_noise(self, distribution, data_size): + """Sample noise to be added to weights for privacy guarantee + + Args: + distribution: the distribution type to pull noise from + data_size: the number of samples + + Returns: noise in shape of layer's weights to be added to the weights. + + """ + distribution = distribution.lower() + input_dim = self._layers[0].kernel.numpy().shape[0] + loss = self.loss + if distribution == 'laplace': + per_class_epsilon = self.epsilon / (self.n_classes) + l2_sensitivity = (2 * + loss.lipchitz_constant(self.class_weight)) / \ + (loss.gamma() * data_size) + unit_vector = tf.random.normal(shape=(input_dim, self.n_classes), + mean=0, + seed=1, + stddev=1.0, + dtype=self._dtype) + unit_vector = unit_vector / tf.math.sqrt( + tf.reduce_sum(tf.math.square(unit_vector), axis=0) + ) + + beta = l2_sensitivity / per_class_epsilon + alpha = input_dim # input_dim + gamma = tf.random.gamma([self.n_classes], + alpha, + beta=1 / beta, + seed=1, + dtype=self._dtype) + return unit_vector * gamma + raise NotImplementedError("distribution: {0} is not " + "currently supported".format(distribution)) diff --git a/privacy/bolton/model_test.py b/privacy/bolton/model_test.py new file mode 100644 index 0000000..87669fd --- /dev/null +++ b/privacy/bolton/model_test.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function \ No newline at end of file diff --git a/privacy/bolton/optimizer.py b/privacy/bolton/optimizer.py new file mode 100644 index 0000000..f8af390 --- /dev/null +++ b/privacy/bolton/optimizer.py @@ -0,0 +1,173 @@ +# Copyright 2018, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Private Optimizer for bolton method""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.python.keras.optimizer_v2 import optimizer_v2 + +_private_attributes = ['_internal_optimizer', 'dtype'] + + +class Private(optimizer_v2.OptimizerV2): + """ + Private optimizer wraps another tf optimizer to be used + as the visible optimizer to the tf model. No matter the optimizer + passed, "Private" enables the bolton model to control the learning rate + based on the strongly convex loss. + """ + def __init__(self, + optimizer: optimizer_v2.OptimizerV2, + dtype=tf.float32 + ): + """Constructor. + + Args: + optimizer: Optimizer_v2 or subclass to be used as the optimizer + (wrapped). + """ + self._internal_optimizer = optimizer + self.dtype = dtype + + def get_config(self): + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer.get_config() + + def limit_learning_rate(self, is_eager, beta, gamma): + """Implements learning rate limitation that is required by the bolton + method for sensitivity bounding of the strongly convex function. + Sets the learning rate to the min(1/beta, 1/(gamma*t)) + + Args: + is_eager: Whether the model is running in eager mode + beta: loss function beta-smoothness + gamma: loss function gamma-strongly convex + + Returns: None + + """ + numerator = tf.Variable(initial_value=1, dtype=self.dtype) + t = tf.cast(self._iterations, self.dtype) + # will exist on the internal optimizer + pred = numerator / beta < numerator / (gamma * t) + if is_eager: # check eagerly + if pred: + self.learning_rate = numerator / beta + else: + self.learning_rate = numerator / (gamma * t) + else: + if pred: + self.learning_rate = numerator / beta + else: + self.learning_rate = numerator / (gamma * t) + + def from_config(self, config, custom_objects=None): + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer.from_config( + config, + custom_objects=custom_objects + ) + + def __getattr__(self, name): + """return _internal_optimizer off self instance, and everything else + from the _internal_optimizer instance. + + Args: + name: + + Returns: attribute from Private if specified to come from self, else + from _internal_optimizer. + + """ + if name in _private_attributes: + return getattr(self, name) + optim = object.__getattribute__(self, '_internal_optimizer') + return object.__getattribute__(optim, name) + + def __setattr__(self, key, value): + """ Set attribute to self instance if its the internal optimizer. + Reroute everything else to the _internal_optimizer. + + Args: + key: attribute name + value: attribute value + + Returns: + + """ + if key in _private_attributes: + object.__setattr__(self, key, value) + else: + setattr(self._internal_optimizer, key, value) + + def _resource_apply_dense(self, grad, handle): + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer._resource_apply_dense(grad, handle) + + def _resource_apply_sparse(self, grad, handle, indices): + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer._resource_apply_sparse( + grad, + handle, + indices + ) + + def get_updates(self, loss, params): + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer.get_updates(loss, params) + + def apply_gradients(self, grads_and_vars, name: str = None): + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer.apply_gradients( + grads_and_vars, + name=name + ) + + def minimize(self, + loss, + var_list, + grad_loss: bool = None, + name: str = None + ): + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer.minimize( + loss, + var_list, + grad_loss, + name + ) + + def _compute_gradients(self, loss, var_list, grad_loss=None): + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer._compute_gradients( + loss, + var_list, + grad_loss=grad_loss + ) + + def get_gradients(self, loss, params): + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer.get_gradients(loss, params) diff --git a/privacy/bolton/optimizer_test.py b/privacy/bolton/optimizer_test.py new file mode 100644 index 0000000..ec8de48 --- /dev/null +++ b/privacy/bolton/optimizer_test.py @@ -0,0 +1,9 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.python.platform import test +from tensorflow.python.keras import keras_parameterized +from privacy.bolton import model +