From 5f46927747ed9a26b03cb4ed63216dc6757a104f Mon Sep 17 00:00:00 2001 From: Christopher Choquette Choo Date: Wed, 5 Jun 2019 17:06:02 -0400 Subject: [PATCH 1/9] Working bolton model without unit tests. --- privacy/bolton/__init__.py | 14 ++ privacy/bolton/loss.py | 280 +++++++++++++++++++++ privacy/bolton/loss_test.py | 3 + privacy/bolton/model.py | 402 +++++++++++++++++++++++++++++++ privacy/bolton/model_test.py | 3 + privacy/bolton/optimizer.py | 173 +++++++++++++ privacy/bolton/optimizer_test.py | 9 + 7 files changed, 884 insertions(+) create mode 100644 privacy/bolton/__init__.py create mode 100644 privacy/bolton/loss.py create mode 100644 privacy/bolton/loss_test.py create mode 100644 privacy/bolton/model.py create mode 100644 privacy/bolton/model_test.py create mode 100644 privacy/bolton/optimizer.py create mode 100644 privacy/bolton/optimizer_test.py diff --git a/privacy/bolton/__init__.py b/privacy/bolton/__init__.py new file mode 100644 index 0000000..46bd079 --- /dev/null +++ b/privacy/bolton/__init__.py @@ -0,0 +1,14 @@ +import sys +from distutils.version import LooseVersion +import tensorflow as tf + +if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): + raise ImportError("Please upgrade your version of tensorflow from: {0} " + "to at least 2.0.0 to use privacy/bolton".format( + LooseVersion(tf.__version__))) +if hasattr(sys, 'skip_tf_privacy_import'): # Useful for standalone scripts. + pass +else: + from privacy.bolton.model import Bolton + from privacy.bolton.loss import Huber + from privacy.bolton.loss import BinaryCrossentropy \ No newline at end of file diff --git a/privacy/bolton/loss.py b/privacy/bolton/loss.py new file mode 100644 index 0000000..dd5d580 --- /dev/null +++ b/privacy/bolton/loss.py @@ -0,0 +1,280 @@ +# Copyright 2018, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Loss functions for bolton method""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import tensorflow as tf +from tensorflow.python.keras import losses +from tensorflow.python.keras.utils import losses_utils +from tensorflow.python.framework import ops as _ops + + +class StrongConvexLoss(losses.Loss): + """ + Strong Convex Loss base class for any loss function that will be used with + Bolton model. Subclasses must be strongly convex and implement the + associated constants. They must also conform to the requirements of tf losses + (see super class) + """ + def __init__(self, + reg_lambda: float, + c: float, + radius_constant: float = 1, + reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, + name: str = None, + dtype=tf.float32, + **kwargs): + """ + Args: + reg_lambda: Weight regularization constant + c: Additional constant for strongly convex convergence. Acts + as a global weight. + radius_constant: constant defining the length of the radius + reduction: reduction type to use. See super class + name: Name of the loss instance + dtype: tf datatype to use for tensor conversions. 
+ """ + super(StrongConvexLoss, self).__init__(reduction=reduction, + name=name, + **kwargs) + self._sample_weight = tf.Variable(initial_value=c, + trainable=False, + dtype=tf.float32) + self._reg_lambda = reg_lambda + self.radius_constant = tf.Variable(initial_value=radius_constant, + trainable=False, + dtype=tf.float32) + self.dtype = dtype + + def radius(self): + """Radius of R-Ball (value to normalize weights to after each batch) + + Returns: radius + + """ + raise NotImplementedError("Radius not implemented for StrongConvex Loss" + "function: %s" % str(self.__class__.__name__)) + + def gamma(self): + """ Gamma strongly convex + + Returns: gamma + + """ + raise NotImplementedError("Gamma not implemented for StrongConvex Loss" + "function: %s" % str(self.__class__.__name__)) + + def beta(self, class_weight): + """Beta smoothess + + Args: + class_weight: the class weights used. + + Returns: Beta + + """ + raise NotImplementedError("Beta not implemented for StrongConvex Loss" + "function: %s" % str(self.__class__.__name__)) + + def lipchitz_constant(self, class_weight): + """ L lipchitz continuous + + Args: + class_weight: class weights used + + Returns: L + + """ + raise NotImplementedError("lipchitz constant not implemented for " + "StrongConvex Loss" + "function: %s" % str(self.__class__.__name__)) + + def reg_lambda(self, convert_to_tensor: bool = False): + """ returns the lambda weight regularization constant, as a tensor if + desired + + Args: + convert_to_tensor: True to convert to tensor, False to leave as + python numeric. + + Returns: reg_lambda + + """ + if convert_to_tensor: + return _ops.convert_to_tensor_v2(self._reg_lambda, dtype=self.dtype) + return self._reg_lambda + + def max_class_weight(self, class_weight): + class_weight = _ops.convert_to_tensor_v2(class_weight, dtype=self.dtype) + return tf.math.reduce_max(class_weight) + + +class Huber(StrongConvexLoss, losses.Huber): + """Strong Convex version of huber loss using l2 weight regularization. 
+ """ + def __init__(self, + reg_lambda: float, + c: float, + radius_constant: float, + delta: float, + reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, + name: str = 'huber', + dtype=tf.float32): + """Constructor. Passes arguments to StrongConvexLoss and Huber Loss. + + Args: + reg_lambda: Weight regularization constant + c: Additional constant for strongly convex convergence. Acts + as a global weight. + radius_constant: constant defining the length of the radius + delta: delta value in huber loss. When to switch from quadratic to + absolute deviation. + reduction: reduction type to use. See super class + name: Name of the loss instance + dtype: tf datatype to use for tensor conversions. + + Returns: + Loss values per sample. + """ + # self.delta = tf.Variable(initial_value=delta, trainable=False) + super(Huber, self).__init__( + reg_lambda, + c, + radius_constant, + delta=delta, + name=name, + reduction=reduction, + dtype=dtype + ) + + def call(self, y_true, y_pred): + """Compute loss + + Args: + y_true: Ground truth values. + y_pred: The predicted values. + + Returns: + Loss values per sample. + """ + return super(Huber, self).call(y_true, y_pred, **self._fn_kwargs) * \ + self._sample_weight + + def radius(self): + """See super class. + """ + return self.radius_constant / self.reg_lambda(True) + + def gamma(self): + """See super class. + """ + return self.reg_lambda(True) + + def beta(self, class_weight): + """See super class. + """ + max_class_weight = self.max_class_weight(class_weight) + return self._sample_weight * max_class_weight / \ + (self.delta * tf.Variable(initial_value=2, trainable=False)) + \ + self.reg_lambda(True) + + def lipchitz_constant(self, class_weight): + """See super class. 
+ """ + # if class_weight is provided, + # it should be a vector of the same size of number of classes + max_class_weight = self.max_class_weight(class_weight) + lc = self._sample_weight * max_class_weight + \ + self.reg_lambda(True) * self.radius() + return lc + + +class BinaryCrossentropy(StrongConvexLoss, losses.BinaryCrossentropy): + """ + Strong Convex version of BinaryCrossentropy loss using l2 weight + regularization. + """ + def __init__(self, + reg_lambda: float, + c: float, + radius_constant: float, + from_logits: bool = True, + label_smoothing: float = 0, + reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, + name: str = 'binarycrossentropy', + dtype=tf.float32): + """ + Args: + reg_lambda: Weight regularization constant + c: Additional constant for strongly convex convergence. Acts + as a global weight. + radius_constant: constant defining the length of the radius + reduction: reduction type to use. See super class + label_smoothing: amount of smoothing to perform on labels + relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x) + name: Name of the loss instance + dtype: tf datatype to use for tensor conversions. + """ + super(BinaryCrossentropy, self).__init__(reg_lambda, + c, + radius_constant, + reduction=reduction, + name=name, + from_logits=from_logits, + label_smoothing=label_smoothing, + dtype=dtype + ) + self.radius_constant = radius_constant + + def call(self, y_true, y_pred): + """Compute loss + + Args: + y_true: Ground truth values. + y_pred: The predicted values. + + Returns: + Loss values per sample. + """ + loss = tf.nn.sigmoid_cross_entropy_with_logits( + labels=y_true, + logits=y_pred + ) + loss = loss * self._sample_weight + return loss + + def radius(self): + """See super class. + """ + return self.radius_constant / self.reg_lambda(True) + + def gamma(self): + """See super class. + """ + return self.reg_lambda(True) + + def beta(self, class_weight): + """See super class. 
+ """ + max_class_weight = self.max_class_weight(class_weight) + return self._sample_weight * max_class_weight + self.reg_lambda(True) + + def lipchitz_constant(self, class_weight): + """See super class. + """ + max_class_weight = self.max_class_weight(class_weight) + return self._sample_weight * max_class_weight + \ + self.reg_lambda(True) * self.radius() diff --git a/privacy/bolton/loss_test.py b/privacy/bolton/loss_test.py new file mode 100644 index 0000000..87669fd --- /dev/null +++ b/privacy/bolton/loss_test.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function \ No newline at end of file diff --git a/privacy/bolton/model.py b/privacy/bolton/model.py new file mode 100644 index 0000000..a600374 --- /dev/null +++ b/privacy/bolton/model.py @@ -0,0 +1,402 @@ +# Copyright 2018, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Bolton model for bolton method of differentially private ML""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import tensorflow as tf +from tensorflow.python.keras.models import Model +from tensorflow.python.keras import optimizers +from tensorflow.python.training.tracking import base as trackable +from tensorflow.python.framework import ops as _ops +from privacy.bolton.loss import StrongConvexLoss +from privacy.bolton.optimizer import Private + + +class Bolton(Model): + """ + Bolton episilon-delta model + Uses 4 key steps to achieve privacy guarantees: + 1. Adds noise to weights after training (output perturbation). + 2. Projects weights to R after each batch + 3. Limits learning rate + 4. Use a strongly convex loss function (see compile) + """ + def __init__(self, + n_classes, + epsilon, + noise_distribution='laplace', + weights_initializer=tf.initializers.GlorotUniform(), + seed=1, + dtype=tf.float32 + ): + """ private constructor. + + Args: + n_classes: number of output classes to predict. + epsilon: level of privacy guarantee + noise_distribution: distribution to pull weight perturbations from + weights_initializer: initializer for weights + seed: random seed to use + dtype: data type to use for tensors + """ + + class MyCustomCallback(tf.keras.callbacks.Callback): + """Custom callback for bolton training requirements. + Implements steps (see Bolton class): + 2. Projects weights to R after each batch + 3. 
Limits learning rate + """ + def on_train_batch_end(self, batch, logs=None): + loss = self.model.loss + self.model.optimizer.limit_learning_rate( + self.model.run_eagerly, + loss.beta(self.model.class_weight), + loss.gamma() + ) + self.model._project_weights_to_r(loss.radius(), False) + + def on_train_end(self, logs=None): + loss = self.model.loss + self.model._project_weights_to_r(loss.radius(), True) + + super(Bolton, self).__init__(name='bolton', dynamic=False) + self.n_classes = n_classes + self.output_layer = tf.keras.layers.Dense( + self.n_classes, + kernel_regularizer=tf.keras.regularizers.l2(), + kernel_initializer=weights_initializer, + ) + # if we do regularization here, we require the user to re-instantiate + # the model each time they want to + # change lambda, unless we standardize modifying it later at .compile + self.force = False + self.noise_distribution = noise_distribution + self.epsilon = epsilon + self.seed = seed + self.__in_fit = False + self._callback = MyCustomCallback() + self._dtype = dtype + + def call(self, inputs): + """Forward pass of network + + Args: + inputs: inputs to neural network + + Returns: + + """ + return self.output_layer(inputs) + + def compile(self, + optimizer='SGD', + loss=None, + metrics=None, + loss_weights=None, + sample_weight_mode=None, + weighted_metrics=None, + target_tensors=None, + distribute=None, + **kwargs): + """See super class. Default optimizer used in Bolton method is SGD. 
+ + """ + if not isinstance(loss, StrongConvexLoss): + raise ValueError("Loss must be subclassed from StrongConvexLoss") + self.output_layer.kernel_regularizer.l2 = loss.reg_lambda() + if not isinstance(optimizer, Private): + optimizer = optimizers.get(optimizer) + if isinstance(self.optimizer, trackable.Trackable): + self._track_trackable( + self.optimizer, name='optimizer', overwrite=True + ) + optimizer = Private(optimizer) + + super(Bolton, self).compile(optimizer, + loss=loss, + metrics=metrics, + loss_weights=loss_weights, + sample_weight_mode=sample_weight_mode, + weighted_metrics=weighted_metrics, + target_tensors=target_tensors, + distribute=distribute, + **kwargs + ) + + def _post_fit(self, x, n_samples): + """Implements 1-time weight changes needed for Bolton method. + In this case, specifically implements the noise addition + assuming a strongly convex function. + + Args: + x: inputs + n_samples: number of samples in the inputs. In case the number + cannot be readily determined by inspecting x. + + Returns: + + """ + if n_samples is not None: + data_size = n_samples + elif hasattr(x, 'shape'): + data_size = x.shape[0] + elif hasattr(x, "__len__"): + data_size = len(x) + else: + if n_samples is None: + raise ValueError("Unable to detect the number of training " + "samples and n_smaples was None. 
" + "either pass a dataset with a .shape or " + "__len__ attribute or explicitly pass the " + "number of samples as n_smaples.") + data_size = n_samples + + for layer in self._layers: + layer.kernel = layer.kernel + self._get_noise( + self.noise_distribution, + data_size + ) + + def fit(self, + x=None, + y=None, + batch_size=None, + epochs=1, + verbose=1, + callbacks=None, + validation_split=0.0, + validation_data=None, + shuffle=True, + class_weight=None, + sample_weight=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + validation_freq=1, + max_queue_size=10, + workers=1, + use_multiprocessing=False, + n_samples=None, + **kwargs): + """Reroutes to super fit with additional Bolton delta-epsilon privacy + requirements implemented. Note, inputs must be normalized s.t. ||x|| < 1 + Requirements are as follows: + 1. Adds noise to weights after training (output perturbation). + 2. Projects weights to R after each batch + 3. Limits learning rate + 4. Use a strongly convex loss function (see compile) + See super implementation for more details. + + Args: + n_samples: the number of individual samples in x. 
+ + Returns: + + """ + self.__in_fit = True + cb = [self._callback] + if callbacks is not None: + cb.extend(callbacks) + callbacks = cb + if class_weight is None: + class_weight = self.calculate_class_weights(class_weight) + self.class_weight = class_weight + out = super(Bolton, self).fit(x=x, + y=y, + batch_size=batch_size, + epochs=epochs, + verbose=verbose, + callbacks=callbacks, + validation_split=validation_split, + validation_data=validation_data, + shuffle=shuffle, + class_weight=class_weight, + sample_weight=sample_weight, + initial_epoch=initial_epoch, + steps_per_epoch=steps_per_epoch, + validation_steps=validation_steps, + validation_freq=validation_freq, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing, + **kwargs + ) + self._post_fit(x, n_samples) + self.__in_fit = False + return out + + def fit_generator(self, + generator, + steps_per_epoch=None, + epochs=1, + verbose=1, + callbacks=None, + validation_data=None, + validation_steps=None, + validation_freq=1, + class_weight=None, + max_queue_size=10, + workers=1, + use_multiprocessing=False, + shuffle=True, + initial_epoch=0, + n_samples=None + ): + """ + This method is the same as fit except for when the passed dataset + is a generator. See super method and fit for more details. 
+ Args: + n_samples: number of individual samples in x + + """ + if class_weight is None: + class_weight = self.calculate_class_weights(class_weight) + self.class_weight = class_weight + out = super(Bolton, self).fit_generator( + generator, + steps_per_epoch=steps_per_epoch, + epochs=epochs, + verbose=verbose, + callbacks=callbacks, + validation_data=validation_data, + validation_steps=validation_steps, + validation_freq=validation_freq, + class_weight=class_weight, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing, + shuffle=shuffle, + initial_epoch=initial_epoch + ) + if not self.__in_fit: + self._post_fit(generator, n_samples) + return out + + def calculate_class_weights(self, + class_weights=None, + class_counts=None, + num_classes=None + ): + """ + Calculates class weighting to be used in training. Can be on + Args: + class_weights: str specifying type, array giving weights, or None. + class_counts: If class_weights is not None, then the number of + samples for each class + num_classes: If class_weights is not None, then the number of + classes. + Returns: class_weights as 1D tensor, to be passed to model's fit method. + + """ + # Value checking + class_keys = ['balanced'] + is_string = False + if isinstance(class_weights, str): + is_string = True + if class_weights not in class_keys: + raise ValueError("Detected string class_weights with " + "value: {0}, which is not one of {1}." 
+ "Please select a valid class_weight type" + "or pass an array".format(class_weights, + class_keys)) + if class_counts is None: + raise ValueError("Class counts must be provided if using" + "class_weights=%s" % class_weights) + if num_classes is None: + raise ValueError("Class counts must be provided if using" + "class_weights=%s" % class_weights) + elif class_weights is not None: + if num_classes is None: + raise ValueError("You must pass a value for num_classes if" + "creating an array of class_weights") + # performing class weight calculation + if class_weights is None: + class_weights = 1 + elif is_string and class_weights == 'balanced': + num_samples = sum(class_counts) + class_weights = tf.Variable( + num_samples / (num_classes * class_counts), + dtype=self._dtype + ) + else: + class_weights = _ops.convert_to_tensor_v2(class_weights) + if len(class_weights.shape) != 1: + raise ValueError("Detected class_weights shape: {0} instead of " + "1D array".format(class_weights.shape)) + if class_weights.shape[0] != num_classes: + raise ValueError( + "Detected array length: {0} instead of: {1}".format( + class_weights.shape[0], + num_classes + ) + ) + return class_weights + + def _project_weights_to_r(self, r, force=False): + """helper method to normalize the weights to the R-ball. + + Args: + r: radius of "R-Ball". Scalar to normalize to. + force: True to normalize regardless of previous weight values. + False to check if weights > R-ball and only normalize then. 
+ + Returns: + + """ + for layer in self._layers: + weight_norm = tf.norm(layer.kernel, axis=0) + if force: + layer.kernel = layer.kernel / (weight_norm / r) + elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self._dtype)) > 0: + layer.kernel = layer.kernel / (weight_norm / r) + + def _get_noise(self, distribution, data_size): + """Sample noise to be added to weights for privacy guarantee + + Args: + distribution: the distribution type to pull noise from + data_size: the number of samples + + Returns: noise in shape of layer's weights to be added to the weights. + + """ + distribution = distribution.lower() + input_dim = self._layers[0].kernel.numpy().shape[0] + loss = self.loss + if distribution == 'laplace': + per_class_epsilon = self.epsilon / (self.n_classes) + l2_sensitivity = (2 * + loss.lipchitz_constant(self.class_weight)) / \ + (loss.gamma() * data_size) + unit_vector = tf.random.normal(shape=(input_dim, self.n_classes), + mean=0, + seed=1, + stddev=1.0, + dtype=self._dtype) + unit_vector = unit_vector / tf.math.sqrt( + tf.reduce_sum(tf.math.square(unit_vector), axis=0) + ) + + beta = l2_sensitivity / per_class_epsilon + alpha = input_dim # input_dim + gamma = tf.random.gamma([self.n_classes], + alpha, + beta=1 / beta, + seed=1, + dtype=self._dtype) + return unit_vector * gamma + raise NotImplementedError("distribution: {0} is not " + "currently supported".format(distribution)) diff --git a/privacy/bolton/model_test.py b/privacy/bolton/model_test.py new file mode 100644 index 0000000..87669fd --- /dev/null +++ b/privacy/bolton/model_test.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function \ No newline at end of file diff --git a/privacy/bolton/optimizer.py b/privacy/bolton/optimizer.py new file mode 100644 index 0000000..f8af390 --- /dev/null +++ b/privacy/bolton/optimizer.py @@ -0,0 +1,173 @@ +# Copyright 2018, The TensorFlow Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Private Optimizer for bolton method""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.python.keras.optimizer_v2 import optimizer_v2 + +_private_attributes = ['_internal_optimizer', 'dtype'] + + +class Private(optimizer_v2.OptimizerV2): + """ + Private optimizer wraps another tf optimizer to be used + as the visible optimizer to the tf model. No matter the optimizer + passed, "Private" enables the bolton model to control the learning rate + based on the strongly convex loss. + """ + def __init__(self, + optimizer: optimizer_v2.OptimizerV2, + dtype=tf.float32 + ): + """Constructor. + + Args: + optimizer: Optimizer_v2 or subclass to be used as the optimizer + (wrapped). + """ + self._internal_optimizer = optimizer + self.dtype = dtype + + def get_config(self): + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer.get_config() + + def limit_learning_rate(self, is_eager, beta, gamma): + """Implements learning rate limitation that is required by the bolton + method for sensitivity bounding of the strongly convex function. 
+ Sets the learning rate to the min(1/beta, 1/(gamma*t)) + + Args: + is_eager: Whether the model is running in eager mode + beta: loss function beta-smoothness + gamma: loss function gamma-strongly convex + + Returns: None + + """ + numerator = tf.Variable(initial_value=1, dtype=self.dtype) + t = tf.cast(self._iterations, self.dtype) + # will exist on the internal optimizer + pred = numerator / beta < numerator / (gamma * t) + if is_eager: # check eagerly + if pred: + self.learning_rate = numerator / beta + else: + self.learning_rate = numerator / (gamma * t) + else: + if pred: + self.learning_rate = numerator / beta + else: + self.learning_rate = numerator / (gamma * t) + + def from_config(self, config, custom_objects=None): + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer.from_config( + config, + custom_objects=custom_objects + ) + + def __getattr__(self, name): + """return _internal_optimizer off self instance, and everything else + from the _internal_optimizer instance. + + Args: + name: + + Returns: attribute from Private if specified to come from self, else + from _internal_optimizer. + + """ + if name in _private_attributes: + return getattr(self, name) + optim = object.__getattribute__(self, '_internal_optimizer') + return object.__getattribute__(optim, name) + + def __setattr__(self, key, value): + """ Set attribute to self instance if its the internal optimizer. + Reroute everything else to the _internal_optimizer. + + Args: + key: attribute name + value: attribute value + + Returns: + + """ + if key in _private_attributes: + object.__setattr__(self, key, value) + else: + setattr(self._internal_optimizer, key, value) + + def _resource_apply_dense(self, grad, handle): + """Reroutes to _internal_optimizer. See super/_internal_optimizer. 
+ """ + return self._internal_optimizer._resource_apply_dense(grad, handle) + + def _resource_apply_sparse(self, grad, handle, indices): + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer._resource_apply_sparse( + grad, + handle, + indices + ) + + def get_updates(self, loss, params): + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer.get_updates(loss, params) + + def apply_gradients(self, grads_and_vars, name: str = None): + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer.apply_gradients( + grads_and_vars, + name=name + ) + + def minimize(self, + loss, + var_list, + grad_loss: bool = None, + name: str = None + ): + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer.minimize( + loss, + var_list, + grad_loss, + name + ) + + def _compute_gradients(self, loss, var_list, grad_loss=None): + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer._compute_gradients( + loss, + var_list, + grad_loss=grad_loss + ) + + def get_gradients(self, loss, params): + """Reroutes to _internal_optimizer. See super/_internal_optimizer. 
+ """ + return self._internal_optimizer.get_gradients(loss, params) diff --git a/privacy/bolton/optimizer_test.py b/privacy/bolton/optimizer_test.py new file mode 100644 index 0000000..ec8de48 --- /dev/null +++ b/privacy/bolton/optimizer_test.py @@ -0,0 +1,9 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.python.platform import test +from tensorflow.python.keras import keras_parameterized +from privacy.bolton import model + From 751eaead545d45bcc47bff7d82656b08c474b434 Mon Sep 17 00:00:00 2001 From: Christopher Choquette Choo Date: Mon, 10 Jun 2019 16:11:47 -0400 Subject: [PATCH 2/9] Working bolton model without unit tests. -- update to include pull request changes changes include: parameter renaming, changing to mixin, moving model to compile, additional tests, fixing huber loss --- privacy/bolton/__init__.py | 4 +- privacy/bolton/loss.py | 502 +++++++++++++++++++++++++------ privacy/bolton/loss_test.py | 324 +++++++++++++++++++- privacy/bolton/model.py | 89 ++++-- privacy/bolton/model_test.py | 493 +++++++++++++++++++++++++++++- privacy/bolton/optimizer.py | 56 ++-- privacy/bolton/optimizer_test.py | 173 +++++++++++ 7 files changed, 1472 insertions(+), 169 deletions(-) diff --git a/privacy/bolton/__init__.py b/privacy/bolton/__init__.py index 46bd079..67b6148 100644 --- a/privacy/bolton/__init__.py +++ b/privacy/bolton/__init__.py @@ -10,5 +10,5 @@ if hasattr(sys, 'skip_tf_privacy_import'): # Useful for standalone scripts. 
pass else: from privacy.bolton.model import Bolton - from privacy.bolton.loss import Huber - from privacy.bolton.loss import BinaryCrossentropy \ No newline at end of file + from privacy.bolton.loss import StrongConvexHuber + from privacy.bolton.loss import StrongConvexBinaryCrossentropy \ No newline at end of file diff --git a/privacy/bolton/loss.py b/privacy/bolton/loss.py index dd5d580..5cc029a 100644 --- a/privacy/bolton/loss.py +++ b/privacy/bolton/loss.py @@ -20,56 +20,33 @@ import tensorflow as tf from tensorflow.python.keras import losses from tensorflow.python.keras.utils import losses_utils from tensorflow.python.framework import ops as _ops +from tensorflow.python.keras.regularizers import L1L2 -class StrongConvexLoss(losses.Loss): +class StrongConvexMixin: """ - Strong Convex Loss base class for any loss function that will be used with + Strong Convex Mixin base class for any loss function that will be used with Bolton model. Subclasses must be strongly convex and implement the associated constants. They must also conform to the requirements of tf losses - (see super class) + (see super class). + + For more details on the strong convexity requirements, see: + Bolt-on Differential Privacy for Scalable Stochastic Gradient + Descent-based Analytics by Xi Wu et. al. """ - def __init__(self, - reg_lambda: float, - c: float, - radius_constant: float = 1, - reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, - name: str = None, - dtype=tf.float32, - **kwargs): - """ - Args: - reg_lambda: Weight regularization constant - c: Additional constant for strongly convex convergence. Acts - as a global weight. - radius_constant: constant defining the length of the radius - reduction: reduction type to use. See super class - name: Name of the loss instance - dtype: tf datatype to use for tensor conversions. 
- """ - super(StrongConvexLoss, self).__init__(reduction=reduction, - name=name, - **kwargs) - self._sample_weight = tf.Variable(initial_value=c, - trainable=False, - dtype=tf.float32) - self._reg_lambda = reg_lambda - self.radius_constant = tf.Variable(initial_value=radius_constant, - trainable=False, - dtype=tf.float32) - self.dtype = dtype def radius(self): - """Radius of R-Ball (value to normalize weights to after each batch) + """Radius, R, of the hypothesis space W. + W is a convex set that forms the hypothesis space. - Returns: radius + Returns: R """ raise NotImplementedError("Radius not implemented for StrongConvex Loss" "function: %s" % str(self.__class__.__name__)) def gamma(self): - """ Gamma strongly convex + """ Strongly convexity, gamma Returns: gamma @@ -78,7 +55,7 @@ class StrongConvexLoss(losses.Loss): "function: %s" % str(self.__class__.__name__)) def beta(self, class_weight): - """Beta smoothess + """Smoothness, beta Args: class_weight: the class weights used. @@ -90,7 +67,7 @@ class StrongConvexLoss(losses.Loss): "function: %s" % str(self.__class__.__name__)) def lipchitz_constant(self, class_weight): - """ L lipchitz continuous + """Lipchitz constant, L Args: class_weight: class weights used @@ -102,43 +79,46 @@ class StrongConvexLoss(losses.Loss): "StrongConvex Loss" "function: %s" % str(self.__class__.__name__)) - def reg_lambda(self, convert_to_tensor: bool = False): - """ returns the lambda weight regularization constant, as a tensor if - desired + def kernel_regularizer(self): + """returns the kernel_regularizer to be used. Any subclass should override + this method if they want a kernel_regularizer (if required for + the loss function to be StronglyConvex + + :return: None or kernel_regularizer layer + """ + return None + + def max_class_weight(self, class_weight, dtype): + """the maximum weighting in class weights (max value) as a scalar tensor Args: - convert_to_tensor: True to convert to tensor, False to leave as - python numeric. 
+ class_weight: class weights used + dtype: the data type for tensor conversions. - Returns: reg_lambda + Returns: maximum class weighting as tensor scalar """ - if convert_to_tensor: - return _ops.convert_to_tensor_v2(self._reg_lambda, dtype=self.dtype) - return self._reg_lambda - - def max_class_weight(self, class_weight): - class_weight = _ops.convert_to_tensor_v2(class_weight, dtype=self.dtype) + class_weight = _ops.convert_to_tensor_v2(class_weight, dtype) return tf.math.reduce_max(class_weight) -class Huber(StrongConvexLoss, losses.Huber): - """Strong Convex version of huber loss using l2 weight regularization. +class StrongConvexHuber(losses.Huber, StrongConvexMixin): + """Strong Convex version of Huber loss using l2 weight regularization. """ + def __init__(self, reg_lambda: float, - c: float, + C: float, radius_constant: float, delta: float, reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, name: str = 'huber', dtype=tf.float32): - """Constructor. Passes arguments to StrongConvexLoss and Huber Loss. + """Constructor. Args: reg_lambda: Weight regularization constant - c: Additional constant for strongly convex convergence. Acts - as a global weight. + C: Penalty parameter C of the loss term radius_constant: constant defining the length of the radius delta: delta value in huber loss. When to switch from quadratic to absolute deviation. @@ -149,15 +129,22 @@ class Huber(StrongConvexLoss, losses.Huber): Returns: Loss values per sample. 
""" - # self.delta = tf.Variable(initial_value=delta, trainable=False) - super(Huber, self).__init__( - reg_lambda, - c, - radius_constant, + if C <= 0: + raise ValueError('c: {0}, should be >= 0'.format(C)) + if reg_lambda <= 0: + raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) + if radius_constant <= 0: + raise ValueError('radius_constant: {0}, should be >= 0'.format( + radius_constant + )) + self.C = C + self.radius_constant = radius_constant + self.dtype = dtype + self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) + super(StrongConvexHuber, self).__init__( delta=delta, name=name, reduction=reduction, - dtype=dtype ) def call(self, y_true, y_pred): @@ -170,46 +157,73 @@ class Huber(StrongConvexLoss, losses.Huber): Returns: Loss values per sample. """ - return super(Huber, self).call(y_true, y_pred, **self._fn_kwargs) * \ - self._sample_weight + # return super(StrongConvexHuber, self).call(y_true, y_pred) * self._sample_weight + h = self._fn_kwargs['delta'] + z = y_pred * y_true + one = tf.constant(1, dtype=self.dtype) + four = tf.constant(4, dtype=self.dtype) + + if z > one + h: + return z - z + elif tf.math.abs(one - z) <= h: + return one / (four * h) * tf.math.pow(one + h - z, 2) + elif z < one - h: + return one - z + else: + raise ValueError('') def radius(self): """See super class. """ - return self.radius_constant / self.reg_lambda(True) + return self.radius_constant / self.reg_lambda def gamma(self): """See super class. """ - return self.reg_lambda(True) + return self.reg_lambda def beta(self, class_weight): """See super class. 
""" - max_class_weight = self.max_class_weight(class_weight) - return self._sample_weight * max_class_weight / \ - (self.delta * tf.Variable(initial_value=2, trainable=False)) + \ - self.reg_lambda(True) + max_class_weight = self.max_class_weight(class_weight, self.dtype) + delta = _ops.convert_to_tensor_v2(self._fn_kwargs['delta'], + dtype=self.dtype + ) + return self.C * max_class_weight / (delta * + tf.constant(2, dtype=self.dtype)) + \ + self.reg_lambda def lipchitz_constant(self, class_weight): """See super class. """ # if class_weight is provided, # it should be a vector of the same size of number of classes - max_class_weight = self.max_class_weight(class_weight) - lc = self._sample_weight * max_class_weight + \ - self.reg_lambda(True) * self.radius() + max_class_weight = self.max_class_weight(class_weight, self.dtype) + lc = self.C * max_class_weight + \ + self.reg_lambda * self.radius() return lc + def kernel_regularizer(self): + """ + l2 loss using reg_lambda as the l2 term (as desired). Required for + this loss function to be strongly convex. + :return: + """ + return L1L2(l2=self.reg_lambda) -class BinaryCrossentropy(StrongConvexLoss, losses.BinaryCrossentropy): + +class StrongConvexBinaryCrossentropy( + losses.BinaryCrossentropy, + StrongConvexMixin +): """ Strong Convex version of BinaryCrossentropy loss using l2 weight regularization. """ + def __init__(self, reg_lambda: float, - c: float, + C: float, radius_constant: float, from_logits: bool = True, label_smoothing: float = 0, @@ -219,8 +233,7 @@ class BinaryCrossentropy(StrongConvexLoss, losses.BinaryCrossentropy): """ Args: reg_lambda: Weight regularization constant - c: Additional constant for strongly convex convergence. Acts - as a global weight. + C: Penalty parameter C of the loss term radius_constant: constant defining the length of the radius reduction: reduction type to use. 
See super class label_smoothing: amount of smoothing to perform on labels @@ -228,15 +241,23 @@ class BinaryCrossentropy(StrongConvexLoss, losses.BinaryCrossentropy): name: Name of the loss instance dtype: tf datatype to use for tensor conversions. """ - super(BinaryCrossentropy, self).__init__(reg_lambda, - c, - radius_constant, - reduction=reduction, - name=name, - from_logits=from_logits, - label_smoothing=label_smoothing, - dtype=dtype - ) + if reg_lambda <= 0: + raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) + if C <= 0: + raise ValueError('c: {0}, should be >= 0'.format(C)) + if radius_constant <= 0: + raise ValueError('radius_constant: {0}, should be >= 0'.format( + radius_constant + )) + self.dtype = dtype + self.C = C + self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) + super(StrongConvexBinaryCrossentropy, self).__init__( + reduction=reduction, + name=name, + from_logits=from_logits, + label_smoothing=label_smoothing, + ) self.radius_constant = radius_constant def call(self, y_true, y_pred): @@ -249,32 +270,319 @@ class BinaryCrossentropy(StrongConvexLoss, losses.BinaryCrossentropy): Returns: Loss values per sample. """ - loss = tf.nn.sigmoid_cross_entropy_with_logits( - labels=y_true, - logits=y_pred - ) - loss = loss * self._sample_weight + # loss = tf.nn.sigmoid_cross_entropy_with_logits( + # labels=y_true, + # logits=y_pred + # ) + loss = super(StrongConvexBinaryCrossentropy, self).call(y_true, y_pred) + loss = loss * self.C return loss def radius(self): """See super class. """ - return self.radius_constant / self.reg_lambda(True) + return self.radius_constant / self.reg_lambda def gamma(self): """See super class. """ - return self.reg_lambda(True) + return self.reg_lambda def beta(self, class_weight): """See super class. 
""" - max_class_weight = self.max_class_weight(class_weight) - return self._sample_weight * max_class_weight + self.reg_lambda(True) + max_class_weight = self.max_class_weight(class_weight, self.dtype) + return self.C * max_class_weight + self.reg_lambda def lipchitz_constant(self, class_weight): """See super class. """ - max_class_weight = self.max_class_weight(class_weight) - return self._sample_weight * max_class_weight + \ - self.reg_lambda(True) * self.radius() + max_class_weight = self.max_class_weight(class_weight, self.dtype) + return self.C * max_class_weight + self.reg_lambda * self.radius() + + def kernel_regularizer(self): + """ + l2 loss using reg_lambda as the l2 term (as desired). Required for + this loss function to be strongly convex. + :return: + """ + return L1L2(l2=self.reg_lambda) + + +# class StrongConvexSparseCategoricalCrossentropy( +# losses.CategoricalCrossentropy, +# StrongConvexMixin +# ): +# """ +# Strong Convex version of CategoricalCrossentropy loss using l2 weight +# regularization. +# """ +# +# def __init__(self, +# reg_lambda: float, +# C: float, +# radius_constant: float, +# from_logits: bool = True, +# label_smoothing: float = 0, +# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, +# name: str = 'binarycrossentropy', +# dtype=tf.float32): +# """ +# Args: +# reg_lambda: Weight regularization constant +# C: Penalty parameter C of the loss term +# radius_constant: constant defining the length of the radius +# reduction: reduction type to use. See super class +# label_smoothing: amount of smoothing to perform on labels +# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x) +# name: Name of the loss instance +# dtype: tf datatype to use for tensor conversions. 
+# """ +# if reg_lambda <= 0: +# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) +# if C <= 0: +# raise ValueError('c: {0}, should be >= 0'.format(C)) +# if radius_constant <= 0: +# raise ValueError('radius_constant: {0}, should be >= 0'.format( +# radius_constant +# )) +# +# self.C = C +# self.dtype = dtype +# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) +# super(StrongConvexSparseCategoricalCrossentropy, self).__init__( +# reduction=reduction, +# name=name, +# from_logits=from_logits, +# label_smoothing=label_smoothing, +# ) +# self.radius_constant = radius_constant +# +# def call(self, y_true, y_pred): +# """Compute loss +# +# Args: +# y_true: Ground truth values. +# y_pred: The predicted values. +# +# Returns: +# Loss values per sample. +# """ +# loss = super() +# loss = loss * self.C +# return loss +# +# def radius(self): +# """See super class. +# """ +# return self.radius_constant / self.reg_lambda +# +# def gamma(self): +# """See super class. +# """ +# return self.reg_lambda +# +# def beta(self, class_weight): +# """See super class. +# """ +# max_class_weight = self.max_class_weight(class_weight, self.dtype) +# return self.C * max_class_weight + self.reg_lambda +# +# def lipchitz_constant(self, class_weight): +# """See super class. +# """ +# max_class_weight = self.max_class_weight(class_weight, self.dtype) +# return self.C * max_class_weight + self.reg_lambda * self.radius() +# +# def kernel_regularizer(self): +# """ +# l2 loss using reg_lambda as the l2 term (as desired). Required for +# this loss function to be strongly convex. +# :return: +# """ +# return L1L2(l2=self.reg_lambda) +# +# class StrongConvexSparseCategoricalCrossentropy( +# losses.SparseCategoricalCrossentropy, +# StrongConvexMixin +# ): +# """ +# Strong Convex version of SparseCategoricalCrossentropy loss using l2 weight +# regularization. 
+# """ +# +# def __init__(self, +# reg_lambda: float, +# C: float, +# radius_constant: float, +# from_logits: bool = True, +# label_smoothing: float = 0, +# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, +# name: str = 'binarycrossentropy', +# dtype=tf.float32): +# """ +# Args: +# reg_lambda: Weight regularization constant +# C: Penalty parameter C of the loss term +# radius_constant: constant defining the length of the radius +# reduction: reduction type to use. See super class +# label_smoothing: amount of smoothing to perform on labels +# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x) +# name: Name of the loss instance +# dtype: tf datatype to use for tensor conversions. +# """ +# if reg_lambda <= 0: +# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) +# if C <= 0: +# raise ValueError('c: {0}, should be >= 0'.format(C)) +# if radius_constant <= 0: +# raise ValueError('radius_constant: {0}, should be >= 0'.format( +# radius_constant +# )) +# +# self.C = C +# self.dtype = dtype +# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) +# super(StrongConvexHuber, self).__init__(reduction=reduction, +# name=name, +# from_logits=from_logits, +# label_smoothing=label_smoothing, +# ) +# self.radius_constant = radius_constant +# +# def call(self, y_true, y_pred): +# """Compute loss +# +# Args: +# y_true: Ground truth values. +# y_pred: The predicted values. +# +# Returns: +# Loss values per sample. +# """ +# loss = super() +# loss = loss * self.C +# return loss +# +# def radius(self): +# """See super class. +# """ +# return self.radius_constant / self.reg_lambda +# +# def gamma(self): +# """See super class. +# """ +# return self.reg_lambda +# +# def beta(self, class_weight): +# """See super class. +# """ +# max_class_weight = self.max_class_weight(class_weight, self.dtype) +# return self.C * max_class_weight + self.reg_lambda +# +# def lipchitz_constant(self, class_weight): +# """See super class. 
+# """ +# max_class_weight = self.max_class_weight(class_weight, self.dtype) +# return self.C * max_class_weight + self.reg_lambda * self.radius() +# +# def kernel_regularizer(self): +# """ +# l2 loss using reg_lambda as the l2 term (as desired). Required for +# this loss function to be strongly convex. +# :return: +# """ +# return L1L2(l2=self.reg_lambda) +# +# +# class StrongConvexCategoricalCrossentropy( +# losses.CategoricalCrossentropy, +# StrongConvexMixin +# ): +# """ +# Strong Convex version of CategoricalCrossentropy loss using l2 weight +# regularization. +# """ +# +# def __init__(self, +# reg_lambda: float, +# C: float, +# radius_constant: float, +# from_logits: bool = True, +# label_smoothing: float = 0, +# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, +# name: str = 'binarycrossentropy', +# dtype=tf.float32): +# """ +# Args: +# reg_lambda: Weight regularization constant +# C: Penalty parameter C of the loss term +# radius_constant: constant defining the length of the radius +# reduction: reduction type to use. See super class +# label_smoothing: amount of smoothing to perform on labels +# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x) +# name: Name of the loss instance +# dtype: tf datatype to use for tensor conversions. +# """ +# if reg_lambda <= 0: +# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) +# if C <= 0: +# raise ValueError('c: {0}, should be >= 0'.format(C)) +# if radius_constant <= 0: +# raise ValueError('radius_constant: {0}, should be >= 0'.format( +# radius_constant +# )) +# +# self.C = C +# self.dtype = dtype +# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) +# super(StrongConvexHuber, self).__init__(reduction=reduction, +# name=name, +# from_logits=from_logits, +# label_smoothing=label_smoothing, +# ) +# self.radius_constant = radius_constant +# +# def call(self, y_true, y_pred): +# """Compute loss +# +# Args: +# y_true: Ground truth values. 
+# y_pred: The predicted values. +# +# Returns: +# Loss values per sample. +# """ +# loss = super() +# loss = loss * self.C +# return loss +# +# def radius(self): +# """See super class. +# """ +# return self.radius_constant / self.reg_lambda +# +# def gamma(self): +# """See super class. +# """ +# return self.reg_lambda +# +# def beta(self, class_weight): +# """See super class. +# """ +# max_class_weight = self.max_class_weight(class_weight, self.dtype) +# return self.C * max_class_weight + self.reg_lambda +# +# def lipchitz_constant(self, class_weight): +# """See super class. +# """ +# max_class_weight = self.max_class_weight(class_weight, self.dtype) +# return self.C * max_class_weight + self.reg_lambda * self.radius() +# +# def kernel_regularizer(self): +# """ +# l2 loss using reg_lambda as the l2 term (as desired). Required for +# this loss function to be strongly convex. +# :return: +# """ +# return L1L2(l2=self.reg_lambda) diff --git a/privacy/bolton/loss_test.py b/privacy/bolton/loss_test.py index 87669fd..bb7dc53 100644 --- a/privacy/bolton/loss_test.py +++ b/privacy/bolton/loss_test.py @@ -1,3 +1,325 @@ +# Copyright 2018, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Unit testing for loss.py""" + from __future__ import absolute_import from __future__ import division -from __future__ import print_function \ No newline at end of file +from __future__ import print_function + +import tensorflow as tf +from tensorflow.python.platform import test +from tensorflow.python.keras import keras_parameterized +from tensorflow.python.keras.optimizer_v2 import adam +from tensorflow.python.keras.optimizer_v2 import adagrad +from tensorflow.python.keras.optimizer_v2 import gradient_descent +from tensorflow.python.keras import losses +from tensorflow.python.framework import test_util +from privacy.bolton import model +from privacy.bolton.loss import StrongConvexBinaryCrossentropy +from privacy.bolton.loss import StrongConvexHuber +from privacy.bolton.loss import StrongConvexMixin +from absl.testing import parameterized +from absl.testing import absltest +from tensorflow.python.keras.regularizers import L1L2 + + +class StrongConvexTests(keras_parameterized.TestCase): + @parameterized.named_parameters([ + {'testcase_name': 'beta not implemented', + 'fn': 'beta', + 'args': [1]}, + {'testcase_name': 'gamma not implemented', + 'fn': 'gamma', + 'args': []}, + {'testcase_name': 'lipchitz not implemented', + 'fn': 'lipchitz_constant', + 'args': [1]}, + {'testcase_name': 'radius not implemented', + 'fn': 'radius', + 'args': []}, + ]) + def test_not_implemented(self, fn, args): + with self.assertRaises(NotImplementedError): + loss = StrongConvexMixin() + getattr(loss, fn, None)(*args) + + @parameterized.named_parameters([ + {'testcase_name': 'radius not implemented', + 'fn': 'kernel_regularizer', + 'args': []}, + ]) + def test_return_none(self, fn, args): + loss = StrongConvexMixin() + ret = getattr(loss, fn, None)(*args) + self.assertEqual(ret, None) + + +class BinaryCrossesntropyTests(keras_parameterized.TestCase): + """tests for BinaryCrossesntropy StrongConvex loss""" + + @parameterized.named_parameters([ + {'testcase_name': 'normal', + 
'reg_lambda': 1, + 'c': 1, + 'radius_constant': 1 + }, + ]) + def test_init_params(self, reg_lambda, c, radius_constant): + # test valid domains for each variable + loss = StrongConvexBinaryCrossentropy(reg_lambda, c, radius_constant) + self.assertIsInstance(loss, StrongConvexBinaryCrossentropy) + + @parameterized.named_parameters([ + {'testcase_name': 'negative c', + 'reg_lambda': 1, + 'c': -1, + 'radius_constant': 1 + }, + {'testcase_name': 'negative radius', + 'reg_lambda': 1, + 'c': 1, + 'radius_constant': -1 + }, + {'testcase_name': 'negative lambda', + 'reg_lambda': -1, + 'c': 1, + 'radius_constant': 1 + }, + ]) + def test_bad_init_params(self, reg_lambda, c, radius_constant): + # test valid domains for each variable + with self.assertRaises(ValueError): + loss = StrongConvexBinaryCrossentropy(reg_lambda, c, radius_constant) + + @test_util.run_all_in_graph_and_eager_modes + @parameterized.named_parameters([ + # [] for compatibility with tensorflow loss calculation + {'testcase_name': 'both positive', + 'logits': [10000], + 'y_true': [1], + 'result': 0, + }, + {'testcase_name': 'positive gradient negative logits', + 'logits': [-10000], + 'y_true': [1], + 'result': 10000, + }, + {'testcase_name': 'positivee gradient positive logits', + 'logits': [10000], + 'y_true': [0], + 'result': 10000, + }, + {'testcase_name': 'both negative', + 'logits': [-10000], + 'y_true': [0], + 'result': 0 + }, + ]) + def test_calculation(self, logits, y_true, result): + logits = tf.Variable(logits, False, dtype=tf.float32) + y_true = tf.Variable(y_true, False, dtype=tf.float32) + loss = StrongConvexBinaryCrossentropy(0.00001, 1, 1) + loss = loss(y_true, logits) + self.assertEqual(loss.numpy(), result) + + @parameterized.named_parameters([ + {'testcase_name': 'beta', + 'init_args': [1, 1, 1], + 'fn': 'beta', + 'args': [1], + 'result': tf.constant(2, dtype=tf.float32) + }, + {'testcase_name': 'gamma', + 'fn': 'gamma', + 'init_args': [1, 1, 1], + 'args': [], + 'result': tf.constant(1, 
dtype=tf.float32), + }, + {'testcase_name': 'lipchitz constant', + 'fn': 'lipchitz_constant', + 'init_args': [1, 1, 1], + 'args': [1], + 'result': tf.constant(2, dtype=tf.float32), + }, + {'testcase_name': 'kernel regularizer', + 'fn': 'kernel_regularizer', + 'init_args': [1, 1, 1], + 'args': [], + 'result': L1L2(l2=1), + }, + ]) + def test_fns(self, init_args, fn, args, result): + loss = StrongConvexBinaryCrossentropy(*init_args) + expected = getattr(loss, fn, lambda: 'fn not found')(*args) + if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor + expected = expected.numpy() + result = result.numpy() + if hasattr(expected, 'l2') and hasattr(result, 'l2'): # both l2 regularizer + expected = expected.l2 + result = result.l2 + self.assertEqual(expected, result) + + +class HuberTests(keras_parameterized.TestCase): + """tests for BinaryCrossesntropy StrongConvex loss""" + + @parameterized.named_parameters([ + {'testcase_name': 'normal', + 'reg_lambda': 1, + 'c': 1, + 'radius_constant': 1, + 'delta': 1, + }, + ]) + def test_init_params(self, reg_lambda, c, radius_constant, delta): + # test valid domains for each variable + loss = StrongConvexHuber(reg_lambda, c, radius_constant, delta) + self.assertIsInstance(loss, StrongConvexHuber) + + @parameterized.named_parameters([ + {'testcase_name': 'negative c', + 'reg_lambda': 1, + 'c': -1, + 'radius_constant': 1, + 'delta': 1 + }, + {'testcase_name': 'negative radius', + 'reg_lambda': 1, + 'c': 1, + 'radius_constant': -1, + 'delta': 1 + }, + {'testcase_name': 'negative lambda', + 'reg_lambda': -1, + 'c': 1, + 'radius_constant': 1, + 'delta': 1 + }, + {'testcase_name': 'negative delta', + 'reg_lambda': -1, + 'c': 1, + 'radius_constant': 1, + 'delta': -1 + }, + ]) + def test_bad_init_params(self, reg_lambda, c, radius_constant, delta): + # test valid domains for each variable + with self.assertRaises(ValueError): + loss = StrongConvexHuber(reg_lambda, c, radius_constant, delta) + + # test the bounds and test 
varied delta's + @test_util.run_all_in_graph_and_eager_modes + @parameterized.named_parameters([ + {'testcase_name': 'delta=1,y_true=1 z>1+h decision boundary', + 'logits': 2.1, + 'y_true': 1, + 'delta': 1, + 'result': 0, + }, + {'testcase_name': 'delta=1,y_true=1 z<1+h decision boundary', + 'logits': 1.9, + 'y_true': 1, + 'delta': 1, + 'result': 0.01*0.25, + }, + {'testcase_name': 'delta=1,y_true=1 1-z< h decision boundary', + 'logits': 0.1, + 'y_true': 1, + 'delta': 1, + 'result': 1.9**2 * 0.25, + }, + {'testcase_name': 'delta=1,y_true=1 z < 1-h decision boundary', + 'logits': -0.1, + 'y_true': 1, + 'delta': 1, + 'result': 1.1, + }, + {'testcase_name': 'delta=2,y_true=1 z>1+h decision boundary', + 'logits': 3.1, + 'y_true': 1, + 'delta': 2, + 'result': 0, + }, + {'testcase_name': 'delta=2,y_true=1 z<1+h decision boundary', + 'logits': 2.9, + 'y_true': 1, + 'delta': 2, + 'result': 0.01*0.125, + }, + {'testcase_name': 'delta=2,y_true=1 1-z < h decision boundary', + 'logits': 1.1, + 'y_true': 1, + 'delta': 2, + 'result': 1.9**2 * 0.125, + }, + {'testcase_name': 'delta=2,y_true=1 z < 1-h decision boundary', + 'logits': -1.1, + 'y_true': 1, + 'delta': 2, + 'result': 2.1, + }, + {'testcase_name': 'delta=1,y_true=-1 z>1+h decision boundary', + 'logits': -2.1, + 'y_true': -1, + 'delta': 1, + 'result': 0, + }, + ]) + def test_calculation(self, logits, y_true, delta, result): + logits = tf.Variable(logits, False, dtype=tf.float32) + y_true = tf.Variable(y_true, False, dtype=tf.float32) + loss = StrongConvexHuber(0.00001, 1, 1, delta) + loss = loss(y_true, logits) + self.assertAllClose(loss.numpy(), result) + + @parameterized.named_parameters([ + {'testcase_name': 'beta', + 'init_args': [1, 1, 1, 1], + 'fn': 'beta', + 'args': [1], + 'result': tf.Variable(1.5, dtype=tf.float32) + }, + {'testcase_name': 'gamma', + 'fn': 'gamma', + 'init_args': [1, 1, 1, 1], + 'args': [], + 'result': tf.Variable(1, dtype=tf.float32), + }, + {'testcase_name': 'lipchitz constant', + 'fn': 
'lipchitz_constant', + 'init_args': [1, 1, 1, 1], + 'args': [1], + 'result': tf.Variable(2, dtype=tf.float32), + }, + {'testcase_name': 'kernel regularizer', + 'fn': 'kernel_regularizer', + 'init_args': [1, 1, 1, 1], + 'args': [], + 'result': L1L2(l2=1), + }, + ]) + def test_fns(self, init_args, fn, args, result): + loss = StrongConvexHuber(*init_args) + expected = getattr(loss, fn, lambda: 'fn not found')(*args) + if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor + expected = expected.numpy() + result = result.numpy() + if hasattr(expected, 'l2') and hasattr(result, 'l2'): # both l2 regularizer + expected = expected.l2 + result = result.l2 + self.assertEqual(expected, result) + + +if __name__ == '__main__': + tf.test.main() \ No newline at end of file diff --git a/privacy/bolton/model.py b/privacy/bolton/model.py index a600374..78ceb7c 100644 --- a/privacy/bolton/model.py +++ b/privacy/bolton/model.py @@ -19,11 +19,12 @@ from __future__ import print_function import tensorflow as tf from tensorflow.python.keras.models import Model from tensorflow.python.keras import optimizers -from tensorflow.python.training.tracking import base as trackable from tensorflow.python.framework import ops as _ops -from privacy.bolton.loss import StrongConvexLoss +from privacy.bolton.loss import StrongConvexMixin from privacy.bolton.optimizer import Private +_accepted_distributions = ['laplace'] + class Bolton(Model): """ @@ -33,12 +34,16 @@ class Bolton(Model): 2. Projects weights to R after each batch 3. Limits learning rate 4. Use a strongly convex loss function (see compile) + + For more details on the strong convexity requirements, see: + Bolt-on Differential Privacy for Scalable Stochastic Gradient + Descent-based Analytics by Xi Wu et. al. """ + def __init__(self, n_classes, epsilon, noise_distribution='laplace', - weights_initializer=tf.initializers.GlorotUniform(), seed=1, dtype=tf.float32 ): @@ -59,6 +64,7 @@ class Bolton(Model): 2. 
Projects weights to R after each batch 3. Limits learning rate """ + def on_train_batch_end(self, batch, logs=None): loss = self.model.loss self.model.optimizer.limit_learning_rate( @@ -72,13 +78,17 @@ class Bolton(Model): loss = self.model.loss self.model._project_weights_to_r(loss.radius(), True) + if epsilon <= 0: + raise ValueError('Detected epsilon: {0}. ' + 'Valid range is 0 < epsilon Date: Thu, 13 Jun 2019 01:01:31 -0400 Subject: [PATCH 3/9] Working bolton model without unit tests. -- moving to Bolton Optimizer Model is now just a convenient wrapper and example for users. Optimizer holds ALL Bolton privacy requirements. Optimizer is used as a context manager, and must be passed the model's layers. Unit tests incomplete, committing for visibility into the design. --- privacy/bolton/loss_test.py | 110 ++++++-- privacy/bolton/model.py | 448 ++++++++++++++++++++----------- privacy/bolton/model_test.py | 190 +++++++------ privacy/bolton/optimizer.py | 254 ++++++++++++++++-- privacy/bolton/optimizer_test.py | 256 ++++++++++++++---- 5 files changed, 913 insertions(+), 345 deletions(-) diff --git a/privacy/bolton/loss_test.py b/privacy/bolton/loss_test.py index bb7dc53..ddb4861 100644 --- a/privacy/bolton/loss_test.py +++ b/privacy/bolton/loss_test.py @@ -18,23 +18,17 @@ from __future__ import division from __future__ import print_function import tensorflow as tf -from tensorflow.python.platform import test from tensorflow.python.keras import keras_parameterized -from tensorflow.python.keras.optimizer_v2 import adam -from tensorflow.python.keras.optimizer_v2 import adagrad -from tensorflow.python.keras.optimizer_v2 import gradient_descent -from tensorflow.python.keras import losses from tensorflow.python.framework import test_util -from privacy.bolton import model +from tensorflow.python.keras.regularizers import L1L2 +from absl.testing import parameterized from privacy.bolton.loss import StrongConvexBinaryCrossentropy from privacy.bolton.loss import 
StrongConvexHuber from privacy.bolton.loss import StrongConvexMixin -from absl.testing import parameterized -from absl.testing import absltest -from tensorflow.python.keras.regularizers import L1L2 -class StrongConvexTests(keras_parameterized.TestCase): +class StrongConvexMixinTests(keras_parameterized.TestCase): + """Tests for the StrongConvexMixin""" @parameterized.named_parameters([ {'testcase_name': 'beta not implemented', 'fn': 'beta', @@ -50,6 +44,12 @@ class StrongConvexTests(keras_parameterized.TestCase): 'args': []}, ]) def test_not_implemented(self, fn, args): + """Test that the given fn's are not implemented on the mixin. + + Args: + fn: fn on Mixin to test + args: arguments to fn of Mixin + """ with self.assertRaises(NotImplementedError): loss = StrongConvexMixin() getattr(loss, fn, None)(*args) @@ -60,6 +60,12 @@ class StrongConvexTests(keras_parameterized.TestCase): 'args': []}, ]) def test_return_none(self, fn, args): + """Test that fn of Mixin returns None + + Args: + fn: fn of Mixin to test + args: arguments to fn of Mixin + """ loss = StrongConvexMixin() ret = getattr(loss, fn, None)(*args) self.assertEqual(ret, None) @@ -71,44 +77,56 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase): @parameterized.named_parameters([ {'testcase_name': 'normal', 'reg_lambda': 1, - 'c': 1, + 'C': 1, 'radius_constant': 1 }, ]) - def test_init_params(self, reg_lambda, c, radius_constant): + def test_init_params(self, reg_lambda, C, radius_constant): + """Test initialization for given arguments + Args: + reg_lambda: initialization value for reg_lambda arg + C: initialization value for C arg + radius_constant: initialization value for radius_constant arg + """ # test valid domains for each variable - loss = StrongConvexBinaryCrossentropy(reg_lambda, c, radius_constant) + loss = StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant) self.assertIsInstance(loss, StrongConvexBinaryCrossentropy) @parameterized.named_parameters([ {'testcase_name': 
'negative c', 'reg_lambda': 1, - 'c': -1, + 'C': -1, 'radius_constant': 1 }, {'testcase_name': 'negative radius', 'reg_lambda': 1, - 'c': 1, + 'C': 1, 'radius_constant': -1 }, {'testcase_name': 'negative lambda', 'reg_lambda': -1, - 'c': 1, + 'C': 1, 'radius_constant': 1 }, ]) - def test_bad_init_params(self, reg_lambda, c, radius_constant): + def test_bad_init_params(self, reg_lambda, C, radius_constant): + """Test invalid domain for given params. Should return ValueError + Args: + reg_lambda: initialization value for reg_lambda arg + C: initialization value for C arg + radius_constant: initialization value for radius_constant arg + """ # test valid domains for each variable with self.assertRaises(ValueError): - loss = StrongConvexBinaryCrossentropy(reg_lambda, c, radius_constant) + StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant) @test_util.run_all_in_graph_and_eager_modes @parameterized.named_parameters([ # [] for compatibility with tensorflow loss calculation {'testcase_name': 'both positive', - 'logits': [10000], - 'y_true': [1], - 'result': 0, + 'logits': [10000], + 'y_true': [1], + 'result': 0, }, {'testcase_name': 'positive gradient negative logits', 'logits': [-10000], @@ -127,6 +145,12 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase): }, ]) def test_calculation(self, logits, y_true, result): + """Test the call method to ensure it returns the correct value + Args: + logits: unscaled output of model + y_true: label + result: correct loss calculation value + """ logits = tf.Variable(logits, False, dtype=tf.float32) y_true = tf.Variable(y_true, False, dtype=tf.float32) loss = StrongConvexBinaryCrossentropy(0.00001, 1, 1) @@ -160,6 +184,13 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase): }, ]) def test_fns(self, init_args, fn, args, result): + """Test that fn of BinaryCrossentropy loss returns the correct result + Args: + init_args: init values for loss instance + fn: the fn to test + args: the arguments to above 
function + result: the correct result from the fn + """ loss = StrongConvexBinaryCrossentropy(*init_args) expected = getattr(loss, fn, lambda: 'fn not found')(*args) if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor @@ -183,6 +214,12 @@ class HuberTests(keras_parameterized.TestCase): }, ]) def test_init_params(self, reg_lambda, c, radius_constant, delta): + """Test initialization for given arguments + Args: + reg_lambda: initialization value for reg_lambda arg + C: initialization value for C arg + radius_constant: initialization value for radius_constant arg + """ # test valid domains for each variable loss = StrongConvexHuber(reg_lambda, c, radius_constant, delta) self.assertIsInstance(loss, StrongConvexHuber) @@ -214,18 +251,24 @@ class HuberTests(keras_parameterized.TestCase): }, ]) def test_bad_init_params(self, reg_lambda, c, radius_constant, delta): + """Test invalid domain for given params. Should return ValueError + Args: + reg_lambda: initialization value for reg_lambda arg + C: initialization value for C arg + radius_constant: initialization value for radius_constant arg + """ # test valid domains for each variable with self.assertRaises(ValueError): - loss = StrongConvexHuber(reg_lambda, c, radius_constant, delta) + StrongConvexHuber(reg_lambda, c, radius_constant, delta) # test the bounds and test varied delta's @test_util.run_all_in_graph_and_eager_modes @parameterized.named_parameters([ {'testcase_name': 'delta=1,y_true=1 z>1+h decision boundary', - 'logits': 2.1, - 'y_true': 1, - 'delta': 1, - 'result': 0, + 'logits': 2.1, + 'y_true': 1, + 'delta': 1, + 'result': 0, }, {'testcase_name': 'delta=1,y_true=1 z<1+h decision boundary', 'logits': 1.9, @@ -277,6 +320,12 @@ class HuberTests(keras_parameterized.TestCase): }, ]) def test_calculation(self, logits, y_true, delta, result): + """Test the call method to ensure it returns the correct value + Args: + logits: unscaled output of model + y_true: label + result: correct loss 
calculation value + """ logits = tf.Variable(logits, False, dtype=tf.float32) y_true = tf.Variable(y_true, False, dtype=tf.float32) loss = StrongConvexHuber(0.00001, 1, 1, delta) @@ -310,6 +359,13 @@ class HuberTests(keras_parameterized.TestCase): }, ]) def test_fns(self, init_args, fn, args, result): + """Test that fn of BinaryCrossentropy loss returns the correct result + Args: + init_args: init values for loss instance + fn: the fn to test + args: the arguments to above function + result: the correct result from the fn + """ loss = StrongConvexHuber(*init_args) expected = getattr(loss, fn, lambda: 'fn not found')(*args) if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor @@ -322,4 +378,4 @@ class HuberTests(keras_parameterized.TestCase): if __name__ == '__main__': - tf.test.main() \ No newline at end of file + tf.test.main() diff --git a/privacy/bolton/model.py b/privacy/bolton/model.py index 78ceb7c..a6731fe 100644 --- a/privacy/bolton/model.py +++ b/privacy/bolton/model.py @@ -21,12 +21,12 @@ from tensorflow.python.keras.models import Model from tensorflow.python.keras import optimizers from tensorflow.python.framework import ops as _ops from privacy.bolton.loss import StrongConvexMixin -from privacy.bolton.optimizer import Private +from privacy.bolton.optimizer import Bolton _accepted_distributions = ['laplace'] -class Bolton(Model): +class BoltonModel(Model): """ Bolton episilon-delta model Uses 4 key steps to achieve privacy guarantees: @@ -42,8 +42,7 @@ class Bolton(Model): def __init__(self, n_classes, - epsilon, - noise_distribution='laplace', + # noise_distribution='laplace', seed=1, dtype=tf.float32 ): @@ -58,47 +57,22 @@ class Bolton(Model): dtype: data type to use for tensors """ - class MyCustomCallback(tf.keras.callbacks.Callback): - """Custom callback for bolton training requirements. - Implements steps (see Bolton class): - 2. Projects weights to R after each batch - 3. 
Limits learning rate - """ - - def on_train_batch_end(self, batch, logs=None): - loss = self.model.loss - self.model.optimizer.limit_learning_rate( - self.model.run_eagerly, - loss.beta(self.model.class_weight), - loss.gamma() - ) - self.model._project_weights_to_r(loss.radius(), False) - - def on_train_end(self, logs=None): - loss = self.model.loss - self.model._project_weights_to_r(loss.radius(), True) - - if epsilon <= 0: - raise ValueError('Detected epsilon: {0}. ' - 'Valid range is 0 < epsilon R-ball and only normalize then. + # + # Returns: + # + # """ + # for layer in self.layers: + # weight_norm = tf.norm(layer.kernel, axis=0) + # if force: + # layer.kernel = layer.kernel / (weight_norm / r) + # elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self._dtype)) > 0: + # layer.kernel = layer.kernel / (weight_norm / r) - Args: - r: radius of "R-Ball". Scalar to normalize to. - force: True to normalize regardless of previous weight values. - False to check if weights > R-ball and only normalize then. + # def _get_noise(self, distribution, data_size): + # """Sample noise to be added to weights for privacy guarantee + # + # Args: + # distribution: the distribution type to pull noise from + # data_size: the number of samples + # + # Returns: noise in shape of layer's weights to be added to the weights. 
+ # + # """ + # distribution = distribution.lower() + # input_dim = self.layers[0].kernel.numpy().shape[0] + # loss = self.loss + # if distribution == _accepted_distributions[0]: # laplace + # per_class_epsilon = self.epsilon / (self.n_classes) + # l2_sensitivity = (2 * + # loss.lipchitz_constant(self.class_weight)) / \ + # (loss.gamma() * data_size) + # unit_vector = tf.random.normal(shape=(input_dim, self.n_classes), + # mean=0, + # seed=1, + # stddev=1.0, + # dtype=self._dtype) + # unit_vector = unit_vector / tf.math.sqrt( + # tf.reduce_sum(tf.math.square(unit_vector), axis=0) + # ) + # + # beta = l2_sensitivity / per_class_epsilon + # alpha = input_dim # input_dim + # gamma = tf.random.gamma([self.n_classes], + # alpha, + # beta=1 / beta, + # seed=1, + # dtype=self._dtype + # ) + # return unit_vector * gamma + # raise NotImplementedError('Noise distribution: {0} is not ' + # 'a valid distribution'.format(distribution)) - Returns: - """ - for layer in self._layers: - weight_norm = tf.norm(layer.kernel, axis=0) - if force: - layer.kernel = layer.kernel / (weight_norm / r) - elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self._dtype)) > 0: - layer.kernel = layer.kernel / (weight_norm / r) +if __name__ == '__main__': + import tensorflow as tf - def _get_noise(self, distribution, data_size): - """Sample noise to be added to weights for privacy guarantee + import os + import time + import matplotlib.pyplot as plt - Args: - distribution: the distribution type to pull noise from - data_size: the number of samples + _URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz' - Returns: noise in shape of layer's weights to be added to the weights. 
+ path_to_zip = tf.keras.utils.get_file('facades.tar.gz', + origin=_URL, + extract=True) - """ - distribution = distribution.lower() - input_dim = self._layers[0].kernel.numpy().shape[0] - loss = self.loss - if distribution == _accepted_distributions[0]: # laplace - per_class_epsilon = self.epsilon / (self.n_classes) - l2_sensitivity = (2 * - loss.lipchitz_constant(self.class_weight)) / \ - (loss.gamma() * data_size) - unit_vector = tf.random.normal(shape=(input_dim, self.n_classes), - mean=0, - seed=1, - stddev=1.0, - dtype=self._dtype) - unit_vector = unit_vector / tf.math.sqrt( - tf.reduce_sum(tf.math.square(unit_vector), axis=0) - ) + PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/') + BUFFER_SIZE = 400 + BATCH_SIZE = 1 + IMG_WIDTH = 256 + IMG_HEIGHT = 256 - beta = l2_sensitivity / per_class_epsilon - alpha = input_dim # input_dim - gamma = tf.random.gamma([self.n_classes], - alpha, - beta=1 / beta, - seed=1, - dtype=self._dtype - ) - return unit_vector * gamma - raise NotImplementedError('Noise distribution: {0} is not ' - 'a valid distribution'.format(distribution)) + + def load(image_file): + image = tf.io.read_file(image_file) + image = tf.image.decode_jpeg(image) + + w = tf.shape(image)[1] + + w = w // 2 + real_image = image[:, :w, :] + input_image = image[:, w:, :] + + input_image = tf.cast(input_image, tf.float32) + real_image = tf.cast(real_image, tf.float32) + + return input_image, real_image + + + inp, re = load(PATH + 'train/100.jpg') + # casting to int for matplotlib to show the image + plt.figure() + plt.imshow(inp / 255.0) + plt.figure() + plt.imshow(re / 255.0) + + + def resize(input_image, real_image, height, width): + input_image = tf.image.resize(input_image, [height, width], + method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) + real_image = tf.image.resize(real_image, [height, width], + method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) + + return input_image, real_image + + + def random_crop(input_image, real_image): + stacked_image = 
tf.stack([input_image, real_image], axis=0) + cropped_image = tf.image.random_crop( + stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3]) + + return cropped_image[0], cropped_image[1] + + + def normalize(input_image, real_image): + input_image = (input_image / 127.5) - 1 + real_image = (real_image / 127.5) - 1 + + return input_image, real_image + + + @tf.function() + def random_jitter(input_image, real_image): + # resizing to 286 x 286 x 3 + input_image, real_image = resize(input_image, real_image, 286, 286) + + # randomly cropping to 256 x 256 x 3 + input_image, real_image = random_crop(input_image, real_image) + + if tf.random.uniform(()) > 0.5: + # random mirroring + input_image = tf.image.flip_left_right(input_image) + real_image = tf.image.flip_left_right(real_image) + + return input_image, real_image + + + def load_image_train(image_file): + input_image, real_image = load(image_file) + input_image, real_image = random_jitter(input_image, real_image) + input_image, real_image = normalize(input_image, real_image) + + return input_image, real_image + + + def load_image_test(image_file): + input_image, real_image = load(image_file) + input_image, real_image = resize(input_image, real_image, + IMG_HEIGHT, IMG_WIDTH) + input_image, real_image = normalize(input_image, real_image) + + return input_image, real_image + + + train_dataset = tf.data.Dataset.list_files(PATH + 'train/*.jpg') + train_dataset = train_dataset.shuffle(BUFFER_SIZE) + train_dataset = train_dataset.map(load_image_train, + num_parallel_calls=tf.data.experimental.AUTOTUNE) + train_dataset = train_dataset.batch(1) + # steps_per_epoch = training_utils.infer_steps_for_dataset( + # train_dataset, None, epochs=1, steps_name='steps') + + # for batch in train_dataset: + # print(batch[1].shape) + test_dataset = tf.data.Dataset.list_files(PATH + 'test/*.jpg') + # shuffling so that for every epoch a different image is generated + # to predict and display the progress of our model. 
+ train_dataset = train_dataset.shuffle(BUFFER_SIZE) + test_dataset = test_dataset.map(load_image_test) + test_dataset = test_dataset.batch(1) + + be = BoltonModel(3, 2) + from tensorflow.python.keras.optimizer_v2 import adam + from privacy.bolton import loss + + test = adam.Adam() + l = loss.StrongConvexBinaryCrossentropy(1, 2, 1) + be.compile(test, l) + print("Eager exeuction: {0}".format(tf.executing_eagerly())) + be.fit(train_dataset, verbose=0, steps_per_epoch=1, n_samples=1) diff --git a/privacy/bolton/model_test.py b/privacy/bolton/model_test.py index c3ca109..53c4c45 100644 --- a/privacy/bolton/model_test.py +++ b/privacy/bolton/model_test.py @@ -32,7 +32,7 @@ from absl.testing import absltest from tensorflow.python.keras.regularizers import L1L2 -class TestLoss(losses.Loss): +class TestLoss(losses.Loss, StrongConvexMixin): """Test loss function for testing Bolton model""" def __init__(self, reg_lambda, C, radius_constant, name='test'): super(TestLoss, self).__init__(name=name) @@ -145,21 +145,25 @@ class InitTests(keras_parameterized.TestCase): self.assertIsInstance(clf, model.Bolton) @parameterized.named_parameters([ - {'testcase_name': 'invalid noise', - 'n_classes': 1, - 'epsilon': 1, - 'noise_distribution': 'not_valid', - 'weights_initializer': tf.initializers.GlorotUniform(), - }, - {'testcase_name': 'invalid epsilon', - 'n_classes': 1, - 'epsilon': -1, - 'noise_distribution': 'laplace', - 'weights_initializer': tf.initializers.GlorotUniform(), - }, + {'testcase_name': 'invalid noise', + 'n_classes': 1, + 'epsilon': 1, + 'noise_distribution': 'not_valid', + 'weights_initializer': tf.initializers.GlorotUniform(), + }, + {'testcase_name': 'invalid epsilon', + 'n_classes': 1, + 'epsilon': -1, + 'noise_distribution': 'laplace', + 'weights_initializer': tf.initializers.GlorotUniform(), + }, ]) def test_bad_init_params( - self, n_classes, epsilon, noise_distribution, weights_initializer): + self, + n_classes, + epsilon, + noise_distribution, + 
weights_initializer): # test invalid domains for each variable, especially noise seed = 1 with self.assertRaises(ValueError): @@ -204,16 +208,16 @@ class InitTests(keras_parameterized.TestCase): self.assertEqual(clf.loss, loss) @parameterized.named_parameters([ - {'testcase_name': 'Not strong loss', - 'n_classes': 1, - 'loss': losses.BinaryCrossentropy(), - 'optimizer': 'adam', - }, - {'testcase_name': 'Not valid optimizer', - 'n_classes': 1, - 'loss': TestLoss(1, 1, 1), - 'optimizer': 'ada', - } + {'testcase_name': 'Not strong loss', + 'n_classes': 1, + 'loss': losses.BinaryCrossentropy(), + 'optimizer': 'adam', + }, + {'testcase_name': 'Not valid optimizer', + 'n_classes': 1, + 'loss': TestLoss(1, 1, 1), + 'optimizer': 'ada', + } ]) def test_bad_compile(self, n_classes, loss, optimizer): # test compilaton of invalid tf.optimizer and non instantiated loss. @@ -250,7 +254,7 @@ def _cat_dataset(n_samples, input_dim, n_classes, t='train', generator=False): y_stack = [] for i_class in range(n_classes): x_stack.append( - tf.constant(1*i_class, tf.float32, (n_samples, input_dim)) + tf.constant(1*i_class, tf.float32, (n_samples, input_dim)) ) y_stack.append( tf.constant(i_class, tf.float32, (n_samples, n_classes)) @@ -258,7 +262,7 @@ def _cat_dataset(n_samples, input_dim, n_classes, t='train', generator=False): x_set, y_set = tf.stack(x_stack), tf.stack(y_stack) if generator: dataset = tf.data.Dataset.from_tensor_slices( - (x_set, y_set) + (x_set, y_set) ) return dataset return x_set, y_set @@ -281,10 +285,10 @@ def _do_fit(n_samples, clf.compile(optimizer, loss) if generator: x = _cat_dataset( - n_samples, - input_dim, - n_classes, - generator=generator + n_samples, + input_dim, + n_classes, + generator=generator ) y = None # x = x.batch(batch_size) @@ -315,26 +319,26 @@ class FitTests(keras_parameterized.TestCase): # @test_util.run_all_in_graph_and_eager_modes @parameterized.named_parameters([ - {'testcase_name': 'iterator fit', - 'generator': False, - 
'reset_n_samples': True, - 'callbacks': None - }, - {'testcase_name': 'iterator fit no samples', - 'generator': False, - 'reset_n_samples': True, - 'callbacks': None - }, - {'testcase_name': 'generator fit', - 'generator': True, - 'reset_n_samples': False, - 'callbacks': None - }, - {'testcase_name': 'with callbacks', - 'generator': True, - 'reset_n_samples': False, - 'callbacks': TestCallback() - }, + {'testcase_name': 'iterator fit', + 'generator': False, + 'reset_n_samples': True, + 'callbacks': None + }, + {'testcase_name': 'iterator fit no samples', + 'generator': False, + 'reset_n_samples': True, + 'callbacks': None + }, + {'testcase_name': 'generator fit', + 'generator': True, + 'reset_n_samples': False, + 'callbacks': None + }, + {'testcase_name': 'with callbacks', + 'generator': True, + 'reset_n_samples': False, + 'callbacks': TestCallback() + }, ]) def test_fit(self, generator, reset_n_samples, callbacks): loss = TestLoss(1, 1, 1) @@ -344,9 +348,19 @@ class FitTests(keras_parameterized.TestCase): epsilon = 1 batch_size = 1 n_samples = 10 - clf = _do_fit(n_samples, input_dim, n_classes, epsilon, generator, batch_size, - reset_n_samples, optimizer, loss, callbacks) - self.assertEqual(hasattr(clf, '_layers'), True) + clf = _do_fit( + n_samples, + input_dim, + n_classes, + epsilon, + generator, + batch_size, + reset_n_samples, + optimizer, + loss, + callbacks + ) + self.assertEqual(hasattr(clf, 'layers'), True) @parameterized.named_parameters([ {'testcase_name': 'generator fit', @@ -368,15 +382,15 @@ class FitTests(keras_parameterized.TestCase): ) clf.compile(optimizer, loss) x = _cat_dataset( - n_samples, - input_dim, - n_classes, - generator=generator + n_samples, + input_dim, + n_classes, + generator=generator ) x = x.batch(batch_size) x = x.shuffle(n_samples // 2) clf.fit_generator(x, n_samples=n_samples) - self.assertEqual(hasattr(clf, '_layers'), True) + self.assertEqual(hasattr(clf, 'layers'), True) @parameterized.named_parameters([ {'testcase_name': 
'iterator no n_samples', @@ -399,32 +413,43 @@ class FitTests(keras_parameterized.TestCase): epsilon = 1 batch_size = 1 n_samples = 10 - _do_fit(n_samples, input_dim, n_classes, epsilon, generator, batch_size, - reset_n_samples, optimizer, loss, None, distribution) + _do_fit( + n_samples, + input_dim, + n_classes, + epsilon, + generator, + batch_size, + reset_n_samples, + optimizer, + loss, + None, + distribution + ) @parameterized.named_parameters([ - {'testcase_name': 'None class_weights', - 'class_weights': None, - 'class_counts': None, - 'num_classes': None, - 'result': 1}, - {'testcase_name': 'class weights array', - 'class_weights': [1, 1], - 'class_counts': [1, 1], - 'num_classes': 2, - 'result': [1, 1]}, - {'testcase_name': 'class weights balanced', - 'class_weights': 'balanced', - 'class_counts': [1, 1], - 'num_classes': 2, - 'result': [1, 1]}, + {'testcase_name': 'None class_weights', + 'class_weights': None, + 'class_counts': None, + 'num_classes': None, + 'result': 1}, + {'testcase_name': 'class weights array', + 'class_weights': [1, 1], + 'class_counts': [1, 1], + 'num_classes': 2, + 'result': [1, 1]}, + {'testcase_name': 'class weights balanced', + 'class_weights': 'balanced', + 'class_counts': [1, 1], + 'num_classes': 2, + 'result': [1, 1]}, ]) def test_class_calculate(self, class_weights, class_counts, num_classes, result - ): + ): clf = model.Bolton(1, 1) expected = clf.calculate_class_weights(class_weights, class_counts, @@ -447,14 +472,14 @@ class FitTests(keras_parameterized.TestCase): 'class_weights': 'balanced', 'class_counts': None, 'num_classes': 1, - 'err_msg': - "Class counts must be provided if using class_weights=balanced"}, + 'err_msg': "Class counts must be provided if " + "using class_weights=balanced"}, {'testcase_name': 'no num classes', 'class_weights': 'balanced', 'class_counts': [1], 'num_classes': None, - 'err_msg': - 'num_classes must be provided if using class_weights=balanced'}, + 'err_msg': 'num_classes must be provided if ' 
+ 'using class_weights=balanced'}, {'testcase_name': 'class counts not array', 'class_weights': 'balanced', 'class_counts': 1, @@ -464,7 +489,7 @@ class FitTests(keras_parameterized.TestCase): 'class_weights': [1], 'class_counts': None, 'num_classes': None, - 'err_msg': "You must pass a value for num_classes if" + 'err_msg': "You must pass a value for num_classes if " "creating an array of class_weights"}, {'testcase_name': 'class counts array, improper shape', 'class_weights': [[1], [1]], @@ -481,7 +506,8 @@ class FitTests(keras_parameterized.TestCase): class_weights, class_counts, num_classes, - err_msg): + err_msg + ): clf = model.Bolton(1, 1) with self.assertRaisesRegexp(ValueError, err_msg): expected = clf.calculate_class_weights(class_weights, diff --git a/privacy/bolton/optimizer.py b/privacy/bolton/optimizer.py index 3b836ee..1ec25b9 100644 --- a/privacy/bolton/optimizer.py +++ b/privacy/bolton/optimizer.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Private Optimizer for bolton method""" +"""Bolton Optimizer for bolton method""" from __future__ import absolute_import from __future__ import division @@ -19,15 +19,16 @@ from __future__ import print_function import tensorflow as tf from tensorflow.python.keras.optimizer_v2 import optimizer_v2 +from privacy.bolton.loss import StrongConvexMixin -_private_attributes = ['_internal_optimizer', 'dtype'] +_accepted_distributions = ['laplace'] -class Private(optimizer_v2.OptimizerV2): +class Bolton(optimizer_v2.OptimizerV2): """ - Private optimizer wraps another tf optimizer to be used + Bolton optimizer wraps another tf optimizer to be used as the visible optimizer to the tf model. 
No matter the optimizer - passed, "Private" enables the bolton model to control the learning rate + passed, "Bolton" enables the bolton model to control the learning rate based on the strongly convex loss. For more details on the strong convexity requirements, see: @@ -36,7 +37,8 @@ class Private(optimizer_v2.OptimizerV2): """ def __init__(self, optimizer: optimizer_v2.OptimizerV2, - dtype=tf.float32 + loss: StrongConvexMixin, + dtype=tf.float32, ): """Constructor. @@ -44,15 +46,100 @@ class Private(optimizer_v2.OptimizerV2): optimizer: Optimizer_v2 or subclass to be used as the optimizer (wrapped). """ + + if not isinstance(loss, StrongConvexMixin): + raise ValueError("loss function must be a Strongly Convex and therfore" + "extend the StrongConvexMixin.") + self._private_attributes = ['_internal_optimizer', + 'dtype', + 'noise_distribution', + 'epsilon', + 'loss', + 'class_weights', + 'input_dim', + 'n_samples', + 'n_classes', + 'layers', + '_model' + ] self._internal_optimizer = optimizer self.dtype = dtype + self.loss = loss def get_config(self): """Reroutes to _internal_optimizer. See super/_internal_optimizer. """ return self._internal_optimizer.get_config() - def limit_learning_rate(self, is_eager, beta, gamma): + def project_weights_to_r(self, force=False): + """helper method to normalize the weights to the R-ball. + + Args: + r: radius of "R-Ball". Scalar to normalize to. + force: True to normalize regardless of previous weight values. + False to check if weights > R-ball and only normalize then. 
+ + Returns: + + """ + r = self.loss.radius() + for layer in self.layers: + if tf.executing_eagerly(): + weight_norm = tf.norm(layer.kernel, axis=0) + if force: + layer.kernel = layer.kernel / (weight_norm / r) + elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self.dtype)) > 0: + layer.kernel = layer.kernel / (weight_norm / r) + else: + weight_norm = tf.norm(layer.kernel, axis=0) + if force: + layer.kernel = layer.kernel / (weight_norm / r) + else: + layer.kernel = tf.cond( + tf.reduce_sum(tf.cast(weight_norm > r, dtype=self.dtype)) > 0, + lambda: layer.kernel / (weight_norm / r), + lambda: layer.kernel + ) + + def get_noise(self, data_size, input_dim, output_dim, class_weight): + """Sample noise to be added to weights for privacy guarantee + + Args: + distribution: the distribution type to pull noise from + data_size: the number of samples + + Returns: noise in shape of layer's weights to be added to the weights. + + """ + loss = self.loss + distribution = self.noise_distribution.lower() + if distribution == _accepted_distributions[0]: # laplace + per_class_epsilon = self.epsilon / (output_dim) + l2_sensitivity = (2 * + loss.lipchitz_constant(class_weight)) / \ + (loss.gamma() * data_size) + unit_vector = tf.random.normal(shape=(input_dim, output_dim), + mean=0, + seed=1, + stddev=1.0, + dtype=self.dtype) + unit_vector = unit_vector / tf.math.sqrt( + tf.reduce_sum(tf.math.square(unit_vector), axis=0) + ) + + beta = l2_sensitivity / per_class_epsilon + alpha = input_dim # input_dim + gamma = tf.random.gamma([output_dim], + alpha, + beta=1 / beta, + seed=1, + dtype=self.dtype + ) + return unit_vector * gamma + raise NotImplementedError('Noise distribution: {0} is not ' + 'a valid distribution'.format(distribution)) + + def limit_learning_rate(self, beta, gamma): """Implements learning rate limitation that is required by the bolton method for sensitivity bounding of the strongly convex function. 
Sets the learning rate to the min(1/beta, 1/(gamma*t)) @@ -65,20 +152,13 @@ class Private(optimizer_v2.OptimizerV2): Returns: None """ - numerator = tf.Variable(initial_value=1, dtype=self.dtype) + numerator = tf.constant(1, dtype=self.dtype) t = tf.cast(self._iterations, self.dtype) # will exist on the internal optimizer - pred = numerator / beta < numerator / (gamma * t) - if is_eager: # check eagerly - if pred: - self.learning_rate = numerator / beta - else: - self.learning_rate = numerator / (gamma * t) + if numerator / beta < numerator / (gamma * t): + self.learning_rate = numerator / beta else: - if pred: - self.learning_rate = numerator / beta - else: - self.learning_rate = numerator / (gamma * t) + self.learning_rate = numerator / (gamma * t) def from_config(self, *args, **kwargs): """Reroutes to _internal_optimizer. See super/_internal_optimizer. @@ -92,14 +172,25 @@ class Private(optimizer_v2.OptimizerV2): Args: name: - Returns: attribute from Private if specified to come from self, else + Returns: attribute from Bolton if specified to come from self, else from _internal_optimizer. """ - if name in _private_attributes: + if name == '_private_attributes': + return getattr(self, name) + elif name in self._private_attributes: return getattr(self, name) optim = object.__getattribute__(self, '_internal_optimizer') - return object.__getattribute__(optim, name) + try: + return object.__getattribute__(optim, name) + except AttributeError: + raise AttributeError("Neither '{0}' nor '{1}' object has attribute '{2}'" + "".format( + self.__class__.__name__, + self._internal_optimizer.__class__.__name__, + name + ) + ) def __setattr__(self, key, value): """ Set attribute to self instance if its the internal optimizer. 
@@ -112,7 +203,9 @@ class Private(optimizer_v2.OptimizerV2): Returns: """ - if key in _private_attributes: + if key == '_private_attributes': + object.__setattr__(self, key, value) + elif key in key in self._private_attributes: object.__setattr__(self, key, value) else: setattr(self._internal_optimizer, key, value) @@ -130,24 +223,135 @@ class Private(optimizer_v2.OptimizerV2): def get_updates(self, loss, params): """Reroutes to _internal_optimizer. See super/_internal_optimizer. """ - return self._internal_optimizer.get_updates(loss, params) + # self.layers = params + out = self._internal_optimizer.get_updates(loss, params) + self.limit_learning_rate(self.loss.beta(self.class_weights), + self.loss.gamma() + ) + self.project_weights_to_r() + return out def apply_gradients(self, *args, **kwargs): """Reroutes to _internal_optimizer. See super/_internal_optimizer. """ - return self._internal_optimizer.apply_gradients(*args, **kwargs) + # grads_and_vars = kwargs.get('grads_and_vars', None) + # grads_and_vars = optimizer_v2._filter_grads(grads_and_vars) + # var_list = [v for (_, v) in grads_and_vars] + # self.layers = var_list + out = self._internal_optimizer.apply_gradients(*args, **kwargs) + self.limit_learning_rate(self.loss.beta(self.class_weights), + self.loss.gamma() + ) + self.project_weights_to_r() + return out def minimize(self, *args, **kwargs): """Reroutes to _internal_optimizer. See super/_internal_optimizer. """ - return self._internal_optimizer.minimize(*args, **kwargs) + # self.layers = kwargs.get('var_list', None) + out = self._internal_optimizer.minimize(*args, **kwargs) + self.limit_learning_rate(self.loss.beta(self.class_weights), + self.loss.gamma() + ) + self.project_weights_to_r() + return out def _compute_gradients(self, *args, **kwargs): """Reroutes to _internal_optimizer. See super/_internal_optimizer. 
""" + # self.layers = kwargs.get('var_list', None) return self._internal_optimizer._compute_gradients(*args, **kwargs) def get_gradients(self, *args, **kwargs): """Reroutes to _internal_optimizer. See super/_internal_optimizer. """ + # self.layers = kwargs.get('params', None) return self._internal_optimizer.get_gradients(*args, **kwargs) + + def __enter__(self): + noise_distribution = self.noise_distribution + epsilon = self.epsilon + class_weights = self.class_weights + n_samples = self.n_samples + if noise_distribution not in _accepted_distributions: + raise ValueError('Detected noise distribution: {0} not one of: {1} valid' + 'distributions'.format(noise_distribution, + _accepted_distributions)) + self.noise_distribution = noise_distribution + self.epsilon = epsilon + self.class_weights = class_weights + self.n_samples = n_samples + return self + + def __call__(self, + noise_distribution, + epsilon, + layers, + class_weights, + n_samples, + n_classes, + ): + """ + + Args: + noise_distribution: the noise distribution to pick. + see _accepted_distributions and get_noise for + possible values. + epsilon: privacy parameter. Lower gives more privacy but less utility. + class_weights: class_weights used + n_samples number of rows/individual samples in the training set + n_classes: number of output classes + layers: list of Keras/Tensorflow layers. + """ + if epsilon <= 0: + raise ValueError('Detected epsilon: {0}. ' + 'Valid range is 0 < epsilon Date: Mon, 17 Jun 2019 13:25:30 -0400 Subject: [PATCH 4/9] Bolton created as optimizer with context manager usage. Unit tests included. Additional loss functions TBD. 
--- privacy/bolton/loss.py | 28 +- privacy/bolton/loss_test.py | 6 +- privacy/bolton/model.py | 432 +++++++------------------------ privacy/bolton/model_test.py | 252 +++++++++--------- privacy/bolton/optimizer.py | 297 ++++++++++++--------- privacy/bolton/optimizer_test.py | 331 +++++++++++++++++++---- 6 files changed, 685 insertions(+), 661 deletions(-) diff --git a/privacy/bolton/loss.py b/privacy/bolton/loss.py index 5cc029a..de49607 100644 --- a/privacy/bolton/loss.py +++ b/privacy/bolton/loss.py @@ -102,7 +102,7 @@ class StrongConvexMixin: return tf.math.reduce_max(class_weight) -class StrongConvexHuber(losses.Huber, StrongConvexMixin): +class StrongConvexHuber(losses.Loss, StrongConvexMixin): """Strong Convex version of Huber loss using l2 weight regularization. """ @@ -112,7 +112,6 @@ class StrongConvexHuber(losses.Huber, StrongConvexMixin): radius_constant: float, delta: float, reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, - name: str = 'huber', dtype=tf.float32): """Constructor. @@ -137,13 +136,17 @@ class StrongConvexHuber(losses.Huber, StrongConvexMixin): raise ValueError('radius_constant: {0}, should be >= 0'.format( radius_constant )) - self.C = C + if delta <= 0: + raise ValueError('delta: {0}, should be >= 0'.format( + delta + )) + self.C = C # pylint: disable=invalid-name + self.delta = delta self.radius_constant = radius_constant self.dtype = dtype self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) super(StrongConvexHuber, self).__init__( - delta=delta, - name=name, + name='huber', reduction=reduction, ) @@ -151,26 +154,25 @@ class StrongConvexHuber(losses.Huber, StrongConvexMixin): """Compute loss Args: - y_true: Ground truth values. + y_true: Ground truth values. One y_pred: The predicted values. Returns: Loss values per sample. 
""" # return super(StrongConvexHuber, self).call(y_true, y_pred) * self._sample_weight - h = self._fn_kwargs['delta'] + h = self.delta z = y_pred * y_true one = tf.constant(1, dtype=self.dtype) four = tf.constant(4, dtype=self.dtype) if z > one + h: - return z - z + return _ops.convert_to_tensor_v2(0, dtype=self.dtype) elif tf.math.abs(one - z) <= h: return one / (four * h) * tf.math.pow(one + h - z, 2) elif z < one - h: return one - z - else: - raise ValueError('') + raise ValueError('') # shouldn't be possible to get here. def radius(self): """See super class. @@ -186,7 +188,7 @@ class StrongConvexHuber(losses.Huber, StrongConvexMixin): """See super class. """ max_class_weight = self.max_class_weight(class_weight, self.dtype) - delta = _ops.convert_to_tensor_v2(self._fn_kwargs['delta'], + delta = _ops.convert_to_tensor_v2(self.delta, dtype=self.dtype ) return self.C * max_class_weight / (delta * @@ -250,7 +252,7 @@ class StrongConvexBinaryCrossentropy( radius_constant )) self.dtype = dtype - self.C = C + self.C = C # pylint: disable=invalid-name self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) super(StrongConvexBinaryCrossentropy, self).__init__( reduction=reduction, @@ -306,7 +308,7 @@ class StrongConvexBinaryCrossentropy( this loss function to be strongly convex. 
:return: """ - return L1L2(l2=self.reg_lambda) + return L1L2(l2=self.reg_lambda/2) # class StrongConvexSparseCategoricalCrossentropy( diff --git a/privacy/bolton/loss_test.py b/privacy/bolton/loss_test.py index ddb4861..488710f 100644 --- a/privacy/bolton/loss_test.py +++ b/privacy/bolton/loss_test.py @@ -79,7 +79,7 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase): 'reg_lambda': 1, 'C': 1, 'radius_constant': 1 - }, + }, # pylint: disable=invalid-name ]) def test_init_params(self, reg_lambda, C, radius_constant): """Test initialization for given arguments @@ -107,7 +107,7 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase): 'reg_lambda': -1, 'C': 1, 'radius_constant': 1 - }, + }, # pylint: disable=invalid-name ]) def test_bad_init_params(self, reg_lambda, C, radius_constant): """Test invalid domain for given params. Should return ValueError @@ -180,7 +180,7 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase): 'fn': 'kernel_regularizer', 'init_args': [1, 1, 1], 'args': [], - 'result': L1L2(l2=1), + 'result': L1L2(l2=0.5), }, ]) def test_fns(self, init_args, fn, args, result): diff --git a/privacy/bolton/model.py b/privacy/bolton/model.py index a6731fe..6f3f48e 100644 --- a/privacy/bolton/model.py +++ b/privacy/bolton/model.py @@ -23,8 +23,6 @@ from tensorflow.python.framework import ops as _ops from privacy.bolton.loss import StrongConvexMixin from privacy.bolton.optimizer import Bolton -_accepted_distributions = ['laplace'] - class BoltonModel(Model): """ @@ -41,41 +39,28 @@ class BoltonModel(Model): """ def __init__(self, - n_classes, - # noise_distribution='laplace', + n_outputs, seed=1, dtype=tf.float32 ): """ private constructor. Args: - n_classes: number of output classes to predict. - epsilon: level of privacy guarantee - noise_distribution: distribution to pull weight perturbations from - weights_initializer: initializer for weights + n_outputs: number of output classes to predict. 
seed: random seed to use dtype: data type to use for tensors """ - - # if noise_distribution not in _accepted_distributions: - # raise ValueError('Detected noise distribution: {0} not one of: {1} valid' - # 'distributions'.format(noise_distribution, - # _accepted_distributions)) - # if epsilon <= 0: - # raise ValueError('Detected epsilon: {0}. ' - # 'Valid range is 0 < epsilon 0.'.format( + n_outputs + )) + self.n_outputs = n_outputs self.seed = seed - self.__in_fit = False self._layers_instantiated = False - # self._callback = MyCustomCallback() self._dtype = dtype - def call(self, inputs): + def call(self, inputs, training=False): # pylint: disable=arguments-differ """Forward pass of network Args: @@ -87,37 +72,30 @@ class BoltonModel(Model): return self.output_layer(inputs) def compile(self, - optimizer='SGD', - loss=None, + optimizer, + loss, metrics=None, loss_weights=None, sample_weight_mode=None, weighted_metrics=None, target_tensors=None, distribute=None, - **kwargs): + kernel_initializer=tf.initializers.GlorotUniform, + **kwargs): # pylint: disable=arguments-differ """See super class. Default optimizer used in Bolton method is SGD. """ - for key, val in StrongConvexMixin.__dict__.items(): - if callable(val) and getattr(loss, key, None) is None: - raise ValueError("Please ensure you are passing a valid StrongConvex " - "loss that has all the required methods " - "implemented. " - "Required method: {0} not found".format(key)) + if not isinstance(loss, StrongConvexMixin): + raise ValueError("loss function must be a Strongly Convex and therefore " + "extend the StrongConvexMixin.") if not self._layers_instantiated: # compile may be called multiple times - kernel_intiializer = kwargs.get('kernel_initializer', - tf.initializers.GlorotUniform) + # for instance, if the input/outputs are not defined until fit. 
self.output_layer = tf.keras.layers.Dense( - self.n_classes, + self.n_outputs, kernel_regularizer=loss.kernel_regularizer(), - kernel_initializer=kernel_intiializer(), + kernel_initializer=kernel_initializer(), ) - # if we don't do regularization here, we require the user to - # re-instantiate the model each time they want to change the penalty - # weighting self._layers_instantiated = True - self.output_layer.kernel_regularizer.l2 = loss.reg_lambda if not isinstance(optimizer, Bolton): optimizer = optimizers.get(optimizer) optimizer = Bolton(optimizer, loss) @@ -133,69 +111,16 @@ class BoltonModel(Model): **kwargs ) - # def _post_fit(self, x, n_samples): - # """Implements 1-time weight changes needed for Bolton method. - # In this case, specifically implements the noise addition - # assuming a strongly convex function. - # - # Args: - # x: inputs - # n_samples: number of samples in the inputs. In case the number - # cannot be readily determined by inspecting x. - # - # Returns: - # - # """ - # data_size = None - # if n_samples is not None: - # data_size = n_samples - # elif hasattr(x, 'shape'): - # data_size = x.shape[0] - # elif hasattr(x, "__len__"): - # data_size = len(x) - # elif data_size is None: - # if n_samples is None: - # raise ValueError("Unable to detect the number of training " - # "samples and n_smaples was None. 
" - # "either pass a dataset with a .shape or " - # "__len__ attribute or explicitly pass the " - # "number of samples as n_smaples.") - # for layer in self.layers: - # # layer.kernel = layer.kernel + self._get_noise( - # # data_size - # # ) - # input_dim = layer.kernel.numpy().shape[0] - # layer.kernel = layer.kernel + self.optimizer.get_noise( - # self.loss, - # data_size, - # input_dim, - # self.n_classes, - # self.class_weight - # ) - def fit(self, x=None, y=None, batch_size=None, - epochs=1, - verbose=1, - callbacks=None, - validation_split=0.0, - validation_data=None, - shuffle=True, class_weight=None, - sample_weight=None, - initial_epoch=0, - steps_per_epoch=None, - validation_steps=None, - validation_freq=1, - max_queue_size=10, - workers=1, - use_multiprocessing=False, n_samples=None, epsilon=2, noise_distribution='laplace', - **kwargs): + steps_per_epoch=None, + **kwargs): # pylint: disable=arguments-differ """Reroutes to super fit with additional Bolton delta-epsilon privacy requirements implemented. Note, inputs must be normalized s.t. ||x|| < 1 Requirements are as follows: @@ -207,92 +132,101 @@ class BoltonModel(Model): Args: n_samples: the number of individual samples in x. + epsilon: privacy parameter, which trades off between utility an privacy. + See the bolton paper for more description. + noise_distribution: the distribution to pull noise from. + class_weight: the class weights to be used. Can be a scalar or 1D tensor + whose dim == n_classes. - Returns: + See the super method for descriptions on the rest of the arguments. 
""" - self.__in_fit = True - # cb = [self.optimizer.callbacks] - # if callbacks is not None: - # cb.extend(callbacks) - # callbacks = cb if class_weight is None: class_weight = self.calculate_class_weights(class_weight) - # self.class_weight = class_weight + if n_samples is not None: + data_size = n_samples + elif hasattr(x, 'shape'): + data_size = x.shape[0] + elif hasattr(x, "__len__"): + data_size = len(x) + else: + data_size = None + batch_size_ = self._validate_or_infer_batch_size(batch_size, + steps_per_epoch, + x + ) + # inferring batch_size to be passed to optimizer. batch_size must remain its + # initial value when passed to super().fit() + if batch_size_ is None: + raise ValueError('batch_size: {0} is an ' + 'invalid value'.format(batch_size_)) with self.optimizer(noise_distribution, epsilon, self.layers, class_weight, - n_samples, - self.n_classes, - ) as optim: + data_size, + self.n_outputs, + batch_size_, + ) as _: out = super(BoltonModel, self).fit(x=x, y=y, batch_size=batch_size, - epochs=epochs, - verbose=verbose, - callbacks=callbacks, - validation_split=validation_split, - validation_data=validation_data, - shuffle=shuffle, class_weight=class_weight, - sample_weight=sample_weight, - initial_epoch=initial_epoch, steps_per_epoch=steps_per_epoch, - validation_steps=validation_steps, - validation_freq=validation_freq, - max_queue_size=max_queue_size, - workers=workers, - use_multiprocessing=use_multiprocessing, **kwargs ) return out def fit_generator(self, generator, - steps_per_epoch=None, - epochs=1, - verbose=1, - callbacks=None, - validation_data=None, - validation_steps=None, - validation_freq=1, class_weight=None, - max_queue_size=10, - workers=1, - use_multiprocessing=False, - shuffle=True, - initial_epoch=0, - n_samples=None - ): + noise_distribution='laplace', + epsilon=2, + n_samples=None, + steps_per_epoch=None, + **kwargs + ): # pylint: disable=arguments-differ """ This method is the same as fit except for when the passed dataset is a 
generator. See super method and fit for more details. Args: n_samples: number of individual samples in x + noise_distribution: the distribution to get noise from. + epsilon: privacy parameter, which trades off utility and privacy. See + Bolton paper for more description. + class_weight: the class weights to be used. Can be a scalar or 1D tensor + whose dim == n_classes. + See the super method for descriptions on the rest of the arguments. """ if class_weight is None: class_weight = self.calculate_class_weights(class_weight) - self.class_weight = class_weight - out = super(BoltonModel, self).fit_generator( - generator, - steps_per_epoch=steps_per_epoch, - epochs=epochs, - verbose=verbose, - callbacks=callbacks, - validation_data=validation_data, - validation_steps=validation_steps, - validation_freq=validation_freq, - class_weight=class_weight, - max_queue_size=max_queue_size, - workers=workers, - use_multiprocessing=use_multiprocessing, - shuffle=shuffle, - initial_epoch=initial_epoch - ) - if not self.__in_fit: - self._post_fit(generator, n_samples) + if n_samples is not None: + data_size = n_samples + elif hasattr(generator, 'shape'): + data_size = generator.shape[0] + elif hasattr(generator, "__len__"): + data_size = len(generator) + else: + data_size = None + batch_size = self._validate_or_infer_batch_size(None, + steps_per_epoch, + generator + ) + with self.optimizer(noise_distribution, + epsilon, + self.layers, + class_weight, + data_size, + self.n_outputs, + batch_size + ) as _: + out = super(BoltonModel, self).fit_generator( + generator, + class_weight=class_weight, + steps_per_epoch=steps_per_epoch, + **kwargs + ) return out def calculate_class_weights(self, @@ -336,7 +270,7 @@ class BoltonModel(Model): "class_weights=%s" % class_weights) elif class_weights is not None: if num_classes is None: - raise ValueError("You must pass a value for num_classes if" + raise ValueError("You must pass a value for num_classes if " "creating an array of class_weights") # 
performing class weight calculation if class_weights is None: @@ -357,195 +291,9 @@ class BoltonModel(Model): "1D array".format(class_weights.shape)) if class_weights.shape[0] != num_classes: raise ValueError( - "Detected array length: {0} instead of: {1}".format( - class_weights.shape[0], - num_classes - ) + "Detected array length: {0} instead of: {1}".format( + class_weights.shape[0], + num_classes + ) ) return class_weights - - # def _project_weights_to_r(self, r, force=False): - # """helper method to normalize the weights to the R-ball. - # - # Args: - # r: radius of "R-Ball". Scalar to normalize to. - # force: True to normalize regardless of previous weight values. - # False to check if weights > R-ball and only normalize then. - # - # Returns: - # - # """ - # for layer in self.layers: - # weight_norm = tf.norm(layer.kernel, axis=0) - # if force: - # layer.kernel = layer.kernel / (weight_norm / r) - # elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self._dtype)) > 0: - # layer.kernel = layer.kernel / (weight_norm / r) - - # def _get_noise(self, distribution, data_size): - # """Sample noise to be added to weights for privacy guarantee - # - # Args: - # distribution: the distribution type to pull noise from - # data_size: the number of samples - # - # Returns: noise in shape of layer's weights to be added to the weights. 
- # - # """ - # distribution = distribution.lower() - # input_dim = self.layers[0].kernel.numpy().shape[0] - # loss = self.loss - # if distribution == _accepted_distributions[0]: # laplace - # per_class_epsilon = self.epsilon / (self.n_classes) - # l2_sensitivity = (2 * - # loss.lipchitz_constant(self.class_weight)) / \ - # (loss.gamma() * data_size) - # unit_vector = tf.random.normal(shape=(input_dim, self.n_classes), - # mean=0, - # seed=1, - # stddev=1.0, - # dtype=self._dtype) - # unit_vector = unit_vector / tf.math.sqrt( - # tf.reduce_sum(tf.math.square(unit_vector), axis=0) - # ) - # - # beta = l2_sensitivity / per_class_epsilon - # alpha = input_dim # input_dim - # gamma = tf.random.gamma([self.n_classes], - # alpha, - # beta=1 / beta, - # seed=1, - # dtype=self._dtype - # ) - # return unit_vector * gamma - # raise NotImplementedError('Noise distribution: {0} is not ' - # 'a valid distribution'.format(distribution)) - - -if __name__ == '__main__': - import tensorflow as tf - - import os - import time - import matplotlib.pyplot as plt - - _URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz' - - path_to_zip = tf.keras.utils.get_file('facades.tar.gz', - origin=_URL, - extract=True) - - PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/') - BUFFER_SIZE = 400 - BATCH_SIZE = 1 - IMG_WIDTH = 256 - IMG_HEIGHT = 256 - - - def load(image_file): - image = tf.io.read_file(image_file) - image = tf.image.decode_jpeg(image) - - w = tf.shape(image)[1] - - w = w // 2 - real_image = image[:, :w, :] - input_image = image[:, w:, :] - - input_image = tf.cast(input_image, tf.float32) - real_image = tf.cast(real_image, tf.float32) - - return input_image, real_image - - - inp, re = load(PATH + 'train/100.jpg') - # casting to int for matplotlib to show the image - plt.figure() - plt.imshow(inp / 255.0) - plt.figure() - plt.imshow(re / 255.0) - - - def resize(input_image, real_image, height, width): - input_image = 
tf.image.resize(input_image, [height, width], - method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) - real_image = tf.image.resize(real_image, [height, width], - method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) - - return input_image, real_image - - - def random_crop(input_image, real_image): - stacked_image = tf.stack([input_image, real_image], axis=0) - cropped_image = tf.image.random_crop( - stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3]) - - return cropped_image[0], cropped_image[1] - - - def normalize(input_image, real_image): - input_image = (input_image / 127.5) - 1 - real_image = (real_image / 127.5) - 1 - - return input_image, real_image - - - @tf.function() - def random_jitter(input_image, real_image): - # resizing to 286 x 286 x 3 - input_image, real_image = resize(input_image, real_image, 286, 286) - - # randomly cropping to 256 x 256 x 3 - input_image, real_image = random_crop(input_image, real_image) - - if tf.random.uniform(()) > 0.5: - # random mirroring - input_image = tf.image.flip_left_right(input_image) - real_image = tf.image.flip_left_right(real_image) - - return input_image, real_image - - - def load_image_train(image_file): - input_image, real_image = load(image_file) - input_image, real_image = random_jitter(input_image, real_image) - input_image, real_image = normalize(input_image, real_image) - - return input_image, real_image - - - def load_image_test(image_file): - input_image, real_image = load(image_file) - input_image, real_image = resize(input_image, real_image, - IMG_HEIGHT, IMG_WIDTH) - input_image, real_image = normalize(input_image, real_image) - - return input_image, real_image - - - train_dataset = tf.data.Dataset.list_files(PATH + 'train/*.jpg') - train_dataset = train_dataset.shuffle(BUFFER_SIZE) - train_dataset = train_dataset.map(load_image_train, - num_parallel_calls=tf.data.experimental.AUTOTUNE) - train_dataset = train_dataset.batch(1) - # steps_per_epoch = training_utils.infer_steps_for_dataset( - # train_dataset, None, 
epochs=1, steps_name='steps') - - # for batch in train_dataset: - # print(batch[1].shape) - test_dataset = tf.data.Dataset.list_files(PATH + 'test/*.jpg') - # shuffling so that for every epoch a different image is generated - # to predict and display the progress of our model. - train_dataset = train_dataset.shuffle(BUFFER_SIZE) - test_dataset = test_dataset.map(load_image_test) - test_dataset = test_dataset.batch(1) - - be = BoltonModel(3, 2) - from tensorflow.python.keras.optimizer_v2 import adam - from privacy.bolton import loss - - test = adam.Adam() - l = loss.StrongConvexBinaryCrossentropy(1, 2, 1) - be.compile(test, l) - print("Eager exeuction: {0}".format(tf.executing_eagerly())) - be.fit(train_dataset, verbose=0, steps_per_epoch=1, n_samples=1) diff --git a/privacy/bolton/model_test.py b/privacy/bolton/model_test.py index 53c4c45..4316a1e 100644 --- a/privacy/bolton/model_test.py +++ b/privacy/bolton/model_test.py @@ -19,25 +19,22 @@ from __future__ import print_function import tensorflow as tf -from tensorflow.python.platform import test from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2 from tensorflow.python.keras import losses from tensorflow.python.framework import ops as _ops -from tensorflow.python.framework import test_util -from privacy.bolton import model -from privacy.bolton.loss import StrongConvexMixin -from absl.testing import parameterized -from absl.testing import absltest from tensorflow.python.keras.regularizers import L1L2 - +from absl.testing import parameterized +from privacy.bolton import model +from privacy.bolton.optimizer import Bolton +from privacy.bolton.loss import StrongConvexMixin class TestLoss(losses.Loss, StrongConvexMixin): """Test loss function for testing Bolton model""" def __init__(self, reg_lambda, C, radius_constant, name='test'): super(TestLoss, self).__init__(name=name) self.reg_lambda = reg_lambda - self.C = C + self.C = C # pylint: 
disable=invalid-name self.radius_constant = radius_constant def radius(self): @@ -78,13 +75,17 @@ class TestLoss(losses.Loss, StrongConvexMixin): """ return _ops.convert_to_tensor_v2(1, dtype=tf.float32) - def call(self, val0, val1): + def call(self, y_true, y_pred): """Loss function that is minimized at the mean of the input points.""" - return 0.5 * tf.reduce_sum(tf.math.squared_difference(val0, val1), axis=1) + return 0.5 * tf.reduce_sum( + tf.math.squared_difference(y_true, y_pred), + axis=1 + ) def max_class_weight(self, class_weight): if class_weight is None: return 1 + raise ValueError('') def kernel_regularizer(self): return L1L2(l2=self.reg_lambda) @@ -116,125 +117,91 @@ class InitTests(keras_parameterized.TestCase): @parameterized.named_parameters([ {'testcase_name': 'normal', - 'n_classes': 1, - 'epsilon': 1, - 'noise_distribution': 'laplace', - 'seed': 1 + 'n_outputs': 1, }, - {'testcase_name': 'extreme range', - 'n_classes': 5, - 'epsilon': 0.1, - 'noise_distribution': 'laplace', - 'seed': 10 - }, - {'testcase_name': 'extreme range2', - 'n_classes': 50, - 'epsilon': 10, - 'noise_distribution': 'laplace', - 'seed': 100 + {'testcase_name': 'many outputs', + 'n_outputs': 100, }, ]) - def test_init_params( - self, n_classes, epsilon, noise_distribution, seed): + def test_init_params(self, n_outputs): + """test initialization of BoltonModel + + Args: + n_outputs: number of output neurons + """ # test valid domains for each variable - clf = model.Bolton(n_classes, - epsilon, - noise_distribution, - seed - ) - self.assertIsInstance(clf, model.Bolton) + clf = model.BoltonModel(n_outputs) + self.assertIsInstance(clf, model.BoltonModel) @parameterized.named_parameters([ - {'testcase_name': 'invalid noise', - 'n_classes': 1, - 'epsilon': 1, - 'noise_distribution': 'not_valid', - 'weights_initializer': tf.initializers.GlorotUniform(), - }, - {'testcase_name': 'invalid epsilon', - 'n_classes': 1, - 'epsilon': -1, - 'noise_distribution': 'laplace', - 
'weights_initializer': tf.initializers.GlorotUniform(), + {'testcase_name': 'invalid n_outputs', + 'n_outputs': -1, }, ]) - def test_bad_init_params( - self, - n_classes, - epsilon, - noise_distribution, - weights_initializer): + def test_bad_init_params(self, n_outputs): + """test bad initializations of BoltonModel that should raise errors + + Args: + n_outputs: number of output neurons + """ # test invalid domains for each variable, especially noise - seed = 1 with self.assertRaises(ValueError): - clf = model.Bolton(n_classes, - epsilon, - noise_distribution, - weights_initializer, - seed - ) + model.BoltonModel(n_outputs) @parameterized.named_parameters([ {'testcase_name': 'string compile', - 'n_classes': 1, + 'n_outputs': 1, 'loss': TestLoss(1, 1, 1), 'optimizer': 'adam', - 'weights_initializer': tf.initializers.GlorotUniform(), }, {'testcase_name': 'test compile', - 'n_classes': 100, + 'n_outputs': 100, 'loss': TestLoss(1, 1, 1), 'optimizer': TestOptimizer(), - 'weights_initializer': tf.initializers.GlorotUniform(), - }, - {'testcase_name': 'invalid weights initializer', - 'n_classes': 1, - 'loss': TestLoss(1, 1, 1), - 'optimizer': TestOptimizer(), - 'weights_initializer': 'not_valid', }, ]) - def test_compile(self, n_classes, loss, optimizer, weights_initializer): + def test_compile(self, n_outputs, loss, optimizer): + """test compilation of BoltonModel + + Args: + n_outputs: number of output neurons + loss: instantiated TestLoss instance + optimizer: instanced TestOptimizer instance + """ # test compilation of valid tf.optimizer and tf.loss - epsilon = 1 - noise_distribution = 'laplace' with self.cached_session(): - clf = model.Bolton(n_classes, - epsilon, - noise_distribution, - weights_initializer - ) + clf = model.BoltonModel(n_outputs) clf.compile(optimizer, loss) self.assertEqual(clf.loss, loss) @parameterized.named_parameters([ {'testcase_name': 'Not strong loss', - 'n_classes': 1, + 'n_outputs': 1, 'loss': losses.BinaryCrossentropy(), 'optimizer': 
'adam', }, {'testcase_name': 'Not valid optimizer', - 'n_classes': 1, + 'n_outputs': 1, 'loss': TestLoss(1, 1, 1), 'optimizer': 'ada', } ]) - def test_bad_compile(self, n_classes, loss, optimizer): + def test_bad_compile(self, n_outputs, loss, optimizer): + """test bad compilations of BoltonModel that should raise errors + + Args: + n_outputs: number of output neurons + loss: instantiated TestLoss instance + optimizer: instanced TestOptimizer instance + """ # test compilaton of invalid tf.optimizer and non instantiated loss. - epsilon = 1 - noise_distribution = 'laplace' - weights_initializer = tf.initializers.GlorotUniform() with self.cached_session(): with self.assertRaises((ValueError, AttributeError)): - clf = model.Bolton(n_classes, - epsilon, - noise_distribution, - weights_initializer - ) + clf = model.BoltonModel(n_outputs) clf.compile(optimizer, loss) -def _cat_dataset(n_samples, input_dim, n_classes, t='train', generator=False): +def _cat_dataset(n_samples, input_dim, n_classes, generator=False): """ Creates a categorically encoded dataset (y is categorical). returns the specified dataset either as a static array or as a generator. 
@@ -245,10 +212,9 @@ def _cat_dataset(n_samples, input_dim, n_classes, t='train', generator=False): n_samples: number of rows input_dim: input dimensionality n_classes: output dimensionality - t: one of 'train', 'val', 'test' generator: False for array, True for generator Returns: - X as (n_samples, input_dim), Y as (n_samples, n_classes) + X as (n_samples, input_dim), Y as (n_samples, n_outputs) """ x_stack = [] y_stack = [] @@ -269,25 +235,39 @@ def _cat_dataset(n_samples, input_dim, n_classes, t='train', generator=False): def _do_fit(n_samples, input_dim, - n_classes, + n_outputs, epsilon, generator, batch_size, reset_n_samples, optimizer, loss, - callbacks, distribution='laplace'): - clf = model.Bolton(n_classes, - epsilon, - distribution - ) + """Helper to instantiate necessary components for fitting and perform a model + fit. + + Args: + n_samples: number of samples in dataset + input_dim: the sample dimensionality + n_outputs: number of output neurons + epsilon: privacy parameter + generator: True to create a generator, False to use an iterator + batch_size: batch_size to use + reset_n_samples: True to set _samples to None prior to fitting. + False does nothing + optimizer: instance of TestOptimizer + loss: instance of TestLoss + distribution: distribution to get noise from. 
+ + Returns: BoltonModel instance + """ + clf = model.BoltonModel(n_outputs) clf.compile(optimizer, loss) if generator: x = _cat_dataset( n_samples, input_dim, - n_classes, + n_outputs, generator=generator ) y = None @@ -295,25 +275,20 @@ def _do_fit(n_samples, x = x.shuffle(n_samples//2) batch_size = None else: - x, y = _cat_dataset(n_samples, input_dim, n_classes, generator=generator) + x, y = _cat_dataset(n_samples, input_dim, n_outputs, generator=generator) if reset_n_samples: n_samples = None - if callbacks is not None: - callbacks = [callbacks] clf.fit(x, y, batch_size=batch_size, n_samples=n_samples, - callbacks=callbacks + noise_distribution=distribution, + epsilon=epsilon ) return clf -class TestCallback(tf.keras.callbacks.Callback): - pass - - class FitTests(keras_parameterized.TestCase): """Test cases for keras model fitting""" @@ -322,27 +297,29 @@ class FitTests(keras_parameterized.TestCase): {'testcase_name': 'iterator fit', 'generator': False, 'reset_n_samples': True, - 'callbacks': None }, {'testcase_name': 'iterator fit no samples', 'generator': False, 'reset_n_samples': True, - 'callbacks': None }, {'testcase_name': 'generator fit', 'generator': True, 'reset_n_samples': False, - 'callbacks': None }, {'testcase_name': 'with callbacks', 'generator': True, 'reset_n_samples': False, - 'callbacks': TestCallback() }, ]) - def test_fit(self, generator, reset_n_samples, callbacks): + def test_fit(self, generator, reset_n_samples): + """Tests fitting of BoltonModel + + Args: + generator: True for generator test, False for iterator test. 
+ reset_n_samples: True to reset the n_samples to None, False does nothing + """ loss = TestLoss(1, 1, 1) - optimizer = TestOptimizer() + optimizer = Bolton(TestOptimizer(), loss) n_classes = 2 input_dim = 5 epsilon = 1 @@ -358,28 +335,27 @@ class FitTests(keras_parameterized.TestCase): reset_n_samples, optimizer, loss, - callbacks ) self.assertEqual(hasattr(clf, 'layers'), True) @parameterized.named_parameters([ {'testcase_name': 'generator fit', 'generator': True, - 'reset_n_samples': False, - 'callbacks': None }, ]) - def test_fit_gen(self, generator, reset_n_samples, callbacks): + def test_fit_gen(self, generator): + """Tests the fit_generator method of BoltonModel + + Args: + generator: True to test with a generator dataset + """ loss = TestLoss(1, 1, 1) optimizer = TestOptimizer() n_classes = 2 input_dim = 5 - epsilon = 1 batch_size = 1 n_samples = 10 - clf = model.Bolton(n_classes, - epsilon - ) + clf = model.BoltonModel(n_classes) clf.compile(optimizer, loss) x = _cat_dataset( n_samples, @@ -405,6 +381,14 @@ class FitTests(keras_parameterized.TestCase): }, ]) def test_bad_fit(self, generator, reset_n_samples, distribution): + """Tests fitting with invalid parameters, which should raise an error + + Args: + generator: True to test with generator, False is iterator + reset_n_samples: True to reset the n_samples param to None prior to + passing it to fit + distribution: distribution to get noise from. 
+ """ with self.assertRaises(ValueError): loss = TestLoss(1, 1, 1) optimizer = TestOptimizer() @@ -423,7 +407,6 @@ class FitTests(keras_parameterized.TestCase): reset_n_samples, optimizer, loss, - None, distribution ) @@ -450,7 +433,15 @@ class FitTests(keras_parameterized.TestCase): num_classes, result ): - clf = model.Bolton(1, 1) + """Tests the BoltonModel calculate_class_weights method + + Args: + class_weights: the class_weights to use + class_counts: count of number of samples for each class + num_classes: number of output neurons + result: expected result + """ + clf = model.BoltonModel(1, 1) expected = clf.calculate_class_weights(class_weights, class_counts, num_classes @@ -508,12 +499,21 @@ class FitTests(keras_parameterized.TestCase): num_classes, err_msg ): - clf = model.Bolton(1, 1) - with self.assertRaisesRegexp(ValueError, err_msg): - expected = clf.calculate_class_weights(class_weights, - class_counts, - num_classes - ) + """Tests the BoltonModel calculate_class_weights method with invalid params + which should raise the expected errors. 
+ + Args: + class_weights: the class_weights to use + class_counts: count of number of samples for each class + num_classes: number of output neurons + err_msg: regex that the raised ValueError message is expected to match + """ + clf = model.BoltonModel(1, 1) + with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method + clf.calculate_class_weights(class_weights, + class_counts, + num_classes + ) if __name__ == '__main__': diff --git a/privacy/bolton/optimizer.py b/privacy/bolton/optimizer.py index 1ec25b9..cfd0b98 100644 --- a/privacy/bolton/optimizer.py +++ b/privacy/bolton/optimizer.py @@ -19,9 +19,74 @@ from __future__ import print_function import tensorflow as tf from tensorflow.python.keras.optimizer_v2 import optimizer_v2 +from tensorflow.python.ops import math_ops +from tensorflow.python import ops as _ops from privacy.bolton.loss import StrongConvexMixin -_accepted_distributions = ['laplace'] +_accepted_distributions = ['laplace'] # implemented distributions for noising + + +class GammaBetaDecreasingStep( + optimizer_v2.learning_rate_schedule.LearningRateSchedule +): + """ + Learning Rate Scheduler using the minimum of 1/beta and 1/(gamma * step) + at each step. A required step for privacy guarantees. + """ + def __init__(self): + self.is_init = False + self.beta = None + self.gamma = None + + def __call__(self, step): + """ + returns the learning rate + Args: + step: the current iteration number + Returns: + decayed learning rate to minimum of 1/beta and 1/(gamma * step) as per + the Bolton privacy requirements. + """ + if not self.is_init: + raise AttributeError('Please initialize the {0} Learning Rate Scheduler.' 
+ 'This is performed automatically by using the ' + '{1} as a context manager, ' + 'as desired'.format(self.__class__.__name__, + Bolton.__class__.__name__ + ) + ) + dtype = self.beta.dtype + one = tf.constant(1, dtype) + return tf.math.minimum(tf.math.reduce_min(one/self.beta), + one/(self.gamma*math_ops.cast(step, dtype)) + ) + + def get_config(self): + """ + config to setup the learning rate scheduler. + """ + return {'beta': self.beta, 'gamma': self.gamma} + + def initialize(self, beta, gamma): + """setup the learning rate scheduler with the beta and gamma values provided + by the loss function. Meant to be used with .fit as the loss params may + depend on values passed to fit. + + Args: + beta: Smoothness value. See StrongConvexMixin + gamma: Strong Convexity parameter. See StrongConvexMixin. + """ + self.is_init = True + self.beta = beta + self.gamma = gamma + + def de_initialize(self): + """De initialize the scheduler after fitting, in case another fit call has + different loss parameters. + """ + self.is_init = False + self.beta = None + self.gamma = None class Bolton(optimizer_v2.OptimizerV2): @@ -31,11 +96,24 @@ class Bolton(optimizer_v2.OptimizerV2): passed, "Bolton" enables the bolton model to control the learning rate based on the strongly convex loss. + To use the Bolton method, you must: + 1. instantiate it with an instantiated tf optimizer and StrongConvexLoss. + 2. use it as a context manager around your .fit method internals. + + This can be accomplished by the following: + optimizer = tf.optimizers.SGD() + loss = privacy.bolton.losses.StrongConvexBinaryCrossentropy() + bolton = Bolton(optimizer, loss) + with bolton(*args) as _: + model.fit() + The args required for the context manager can be found in the __call__ + method. + For more details on the strong convexity requirements, see: Bolt-on Differential Privacy for Scalable Stochastic Gradient Descent-based Analytics by Xi Wu et. al. 
""" - def __init__(self, + def __init__(self, # pylint: disable=super-init-not-called optimizer: optimizer_v2.OptimizerV2, loss: StrongConvexMixin, dtype=tf.float32, @@ -45,10 +123,11 @@ class Bolton(optimizer_v2.OptimizerV2): Args: optimizer: Optimizer_v2 or subclass to be used as the optimizer (wrapped). + loss: StrongConvexLoss function that the model is being compiled with. """ if not isinstance(loss, StrongConvexMixin): - raise ValueError("loss function must be a Strongly Convex and therfore" + raise ValueError("loss function must be a Strongly Convex and therefore " "extend the StrongConvexMixin.") self._private_attributes = ['_internal_optimizer', 'dtype', @@ -58,13 +137,19 @@ class Bolton(optimizer_v2.OptimizerV2): 'class_weights', 'input_dim', 'n_samples', - 'n_classes', + 'n_outputs', 'layers', - '_model' + 'batch_size', + '_is_init' ] self._internal_optimizer = optimizer + self.learning_rate = GammaBetaDecreasingStep() # use the Bolton Learning + # rate scheduler, as required for privacy guarantees. This will still need + # to get values from the loss function near the time that .fit is called + # on the model (when this optimizer will be called as a context manager) self.dtype = dtype self.loss = loss + self._is_init = False def get_config(self): """Reroutes to _internal_optimizer. See super/_internal_optimizer. @@ -75,49 +160,44 @@ class Bolton(optimizer_v2.OptimizerV2): """helper method to normalize the weights to the R-ball. Args: - r: radius of "R-Ball". Scalar to normalize to. force: True to normalize regardless of previous weight values. False to check if weights > R-ball and only normalize then. 
Returns: """ - r = self.loss.radius() + radius = self.loss.radius() for layer in self.layers: - if tf.executing_eagerly(): - weight_norm = tf.norm(layer.kernel, axis=0) - if force: - layer.kernel = layer.kernel / (weight_norm / r) - elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self.dtype)) > 0: - layer.kernel = layer.kernel / (weight_norm / r) + weight_norm = tf.norm(layer.kernel, axis=0) + if force: + layer.kernel = layer.kernel / (weight_norm / radius) else: - weight_norm = tf.norm(layer.kernel, axis=0) - if force: - layer.kernel = layer.kernel / (weight_norm / r) - else: - layer.kernel = tf.cond( - tf.reduce_sum(tf.cast(weight_norm > r, dtype=self.dtype)) > 0, - lambda: layer.kernel / (weight_norm / r), - lambda: layer.kernel - ) + layer.kernel = tf.cond( + tf.reduce_sum(tf.cast(weight_norm > radius, dtype=self.dtype)) > 0, + lambda k=layer.kernel, w=weight_norm, r=radius: k / (w / r), # pylint: disable=cell-var-from-loop + lambda k=layer.kernel: k # pylint: disable=cell-var-from-loop + ) - def get_noise(self, data_size, input_dim, output_dim, class_weight): + def get_noise(self, input_dim, output_dim): """Sample noise to be added to weights for privacy guarantee Args: - distribution: the distribution type to pull noise from - data_size: the number of samples + input_dim: the input dimensionality for the weights + output_dim the output dimensionality for the weights Returns: noise in shape of layer's weights to be added to the weights. 
""" + if not self._is_init: + raise Exception('This method must be called from within the optimizer\'s ' + 'context.') loss = self.loss distribution = self.noise_distribution.lower() if distribution == _accepted_distributions[0]: # laplace per_class_epsilon = self.epsilon / (output_dim) l2_sensitivity = (2 * - loss.lipchitz_constant(class_weight)) / \ - (loss.gamma() * data_size) + loss.lipchitz_constant(self.class_weights)) / \ + (loss.gamma() * self.n_samples * self.batch_size) unit_vector = tf.random.normal(shape=(input_dim, output_dim), mean=0, seed=1, @@ -139,28 +219,7 @@ class Bolton(optimizer_v2.OptimizerV2): raise NotImplementedError('Noise distribution: {0} is not ' 'a valid distribution'.format(distribution)) - def limit_learning_rate(self, beta, gamma): - """Implements learning rate limitation that is required by the bolton - method for sensitivity bounding of the strongly convex function. - Sets the learning rate to the min(1/beta, 1/(gamma*t)) - - Args: - is_eager: Whether the model is running in eager mode - beta: loss function beta-smoothness - gamma: loss function gamma-strongly convex - - Returns: None - - """ - numerator = tf.constant(1, dtype=self.dtype) - t = tf.cast(self._iterations, self.dtype) - # will exist on the internal optimizer - if numerator / beta < numerator / (gamma * t): - self.learning_rate = numerator / beta - else: - self.learning_rate = numerator / (gamma * t) - - def from_config(self, *args, **kwargs): + def from_config(self, *args, **kwargs): # pylint: disable=arguments-differ """Reroutes to _internal_optimizer. See super/_internal_optimizer. """ return self._internal_optimizer.from_config(*args, **kwargs) @@ -176,21 +235,19 @@ class Bolton(optimizer_v2.OptimizerV2): from _internal_optimizer. 
""" - if name == '_private_attributes': - return getattr(self, name) - elif name in self._private_attributes: + if name == '_private_attributes' or name in self._private_attributes: return getattr(self, name) optim = object.__getattribute__(self, '_internal_optimizer') try: return object.__getattribute__(optim, name) except AttributeError: - raise AttributeError("Neither '{0}' nor '{1}' object has attribute '{2}'" - "".format( - self.__class__.__name__, - self._internal_optimizer.__class__.__name__, - name - ) - ) + raise AttributeError( + "Neither '{0}' nor '{1}' object has attribute '{2}'" + "".format(self.__class__.__name__, + self._internal_optimizer.__class__.__name__, + name + ) + ) def __setattr__(self, key, value): """ Set attribute to self instance if its the internal optimizer. @@ -205,113 +262,110 @@ class Bolton(optimizer_v2.OptimizerV2): """ if key == '_private_attributes': object.__setattr__(self, key, value) - elif key in key in self._private_attributes: + elif key in self._private_attributes: object.__setattr__(self, key, value) else: setattr(self._internal_optimizer, key, value) - def _resource_apply_dense(self, *args, **kwargs): + def _resource_apply_dense(self, *args, **kwargs): # pylint: disable=arguments-differ """Reroutes to _internal_optimizer. See super/_internal_optimizer. """ - return self._internal_optimizer._resource_apply_dense(*args, **kwargs) + return self._internal_optimizer._resource_apply_dense(*args, **kwargs) # pylint: disable=protected-access - def _resource_apply_sparse(self, *args, **kwargs): + def _resource_apply_sparse(self, *args, **kwargs): # pylint: disable=arguments-differ """Reroutes to _internal_optimizer. See super/_internal_optimizer. """ - return self._internal_optimizer._resource_apply_sparse(*args, **kwargs) + return self._internal_optimizer._resource_apply_sparse(*args, **kwargs) # pylint: disable=protected-access def get_updates(self, loss, params): """Reroutes to _internal_optimizer. 
See super/_internal_optimizer. """ - # self.layers = params out = self._internal_optimizer.get_updates(loss, params) - self.limit_learning_rate(self.loss.beta(self.class_weights), - self.loss.gamma() - ) self.project_weights_to_r() return out - def apply_gradients(self, *args, **kwargs): + def apply_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ """Reroutes to _internal_optimizer. See super/_internal_optimizer. """ - # grads_and_vars = kwargs.get('grads_and_vars', None) - # grads_and_vars = optimizer_v2._filter_grads(grads_and_vars) - # var_list = [v for (_, v) in grads_and_vars] - # self.layers = var_list out = self._internal_optimizer.apply_gradients(*args, **kwargs) - self.limit_learning_rate(self.loss.beta(self.class_weights), - self.loss.gamma() - ) self.project_weights_to_r() return out - def minimize(self, *args, **kwargs): + def minimize(self, *args, **kwargs): # pylint: disable=arguments-differ """Reroutes to _internal_optimizer. See super/_internal_optimizer. """ - # self.layers = kwargs.get('var_list', None) out = self._internal_optimizer.minimize(*args, **kwargs) - self.limit_learning_rate(self.loss.beta(self.class_weights), - self.loss.gamma() - ) self.project_weights_to_r() return out - def _compute_gradients(self, *args, **kwargs): + def _compute_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ,protected-access """Reroutes to _internal_optimizer. See super/_internal_optimizer. """ - # self.layers = kwargs.get('var_list', None) - return self._internal_optimizer._compute_gradients(*args, **kwargs) + return self._internal_optimizer._compute_gradients(*args, **kwargs) # pylint: disable=protected-access - def get_gradients(self, *args, **kwargs): + def get_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ """Reroutes to _internal_optimizer. See super/_internal_optimizer. 
""" - # self.layers = kwargs.get('params', None) return self._internal_optimizer.get_gradients(*args, **kwargs) def __enter__(self): - noise_distribution = self.noise_distribution - epsilon = self.epsilon - class_weights = self.class_weights - n_samples = self.n_samples - if noise_distribution not in _accepted_distributions: - raise ValueError('Detected noise distribution: {0} not one of: {1} valid' - 'distributions'.format(noise_distribution, - _accepted_distributions)) - self.noise_distribution = noise_distribution - self.epsilon = epsilon - self.class_weights = class_weights - self.n_samples = n_samples + """Context manager call at the beginning of with statement. + + Returns: + self, to be used in context manager + """ + self._is_init = True return self def __call__(self, - noise_distribution, - epsilon, - layers, + noise_distribution: str, + epsilon: float, + layers: list, class_weights, n_samples, - n_classes, + n_outputs, + batch_size ): - """ + """Entry point from context. Accepts required values for bolton method and + stores them on the optimizer for use throughout fitting. Args: noise_distribution: the noise distribution to pick. see _accepted_distributions and get_noise for possible values. epsilon: privacy parameter. Lower gives more privacy but less utility. - class_weights: class_weights used + layers: list of Keras/Tensorflow layers. Can be found as model.layers + class_weights: class_weights used, which may either be a scalar or 1D + tensor with dim == n_classes. n_samples number of rows/individual samples in the training set - n_classes: number of output classes - layers: list of Keras/Tensorflow layers. + n_outputs: number of output classes + batch_size: batch size used. """ if epsilon <= 0: raise ValueError('Detected epsilon: {0}. ' 'Valid range is 0 < epsilon Date: Mon, 17 Jun 2019 14:46:04 -0400 Subject: [PATCH 5/9] Update Huber loss regularization term and some small changes across loss parameters. 
--- privacy/bolton/loss.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/privacy/bolton/loss.py b/privacy/bolton/loss.py index de49607..4ed0479 100644 --- a/privacy/bolton/loss.py +++ b/privacy/bolton/loss.py @@ -58,7 +58,8 @@ class StrongConvexMixin: """Smoothness, beta Args: - class_weight: the class weights used. + class_weight: the class weights as scalar or 1d tensor, where its + dimensionality is equal to the number of outputs. Returns: Beta @@ -154,7 +155,7 @@ class StrongConvexHuber(losses.Loss, StrongConvexMixin): """Compute loss Args: - y_true: Ground truth values. One + y_true: Ground truth values. One hot encoded using -1 and 1. y_pred: The predicted values. Returns: @@ -211,7 +212,7 @@ class StrongConvexHuber(losses.Loss, StrongConvexMixin): this loss function to be strongly convex. :return: """ - return L1L2(l2=self.reg_lambda) + return L1L2(l2=self.reg_lambda/2) class StrongConvexBinaryCrossentropy( @@ -230,7 +231,6 @@ class StrongConvexBinaryCrossentropy( from_logits: bool = True, label_smoothing: float = 0, reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, - name: str = 'binarycrossentropy', dtype=tf.float32): """ Args: @@ -239,7 +239,9 @@ class StrongConvexBinaryCrossentropy( radius_constant: constant defining the length of the radius reduction: reduction type to use. See super class label_smoothing: amount of smoothing to perform on labels - relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x) + relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x). + Note, the impact of this parameter's effect on privacy + is not known and thus the default should be used. name: Name of the loss instance dtype: tf datatype to use for tensor conversions. 
""" @@ -256,7 +258,7 @@ class StrongConvexBinaryCrossentropy( self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) super(StrongConvexBinaryCrossentropy, self).__init__( reduction=reduction, - name=name, + name='binarycrossentropy', from_logits=from_logits, label_smoothing=label_smoothing, ) From f41be2c598636f949b97244941ca5cd6ad27d31c Mon Sep 17 00:00:00 2001 From: Christopher Choquette Choo Date: Wed, 19 Jun 2019 10:46:30 -0400 Subject: [PATCH 6/9] Bolton implementation and unit tests. Has two pre-implemented loss functions. --- privacy/__init__.py | 6 + privacy/bolton/__init__.py | 7 +- privacy/bolton/{loss.py => losses.py} | 287 +----------------- .../bolton/{loss_test.py => losses_test.py} | 10 +- privacy/bolton/{model.py => models.py} | 13 +- .../bolton/{model_test.py => models_test.py} | 26 +- .../bolton/{optimizer.py => optimizers.py} | 2 +- .../{optimizer_test.py => optimizers_test.py} | 6 +- 8 files changed, 49 insertions(+), 308 deletions(-) rename privacy/bolton/{loss.py => losses.py} (51%) rename privacy/bolton/{loss_test.py => losses_test.py} (98%) rename privacy/bolton/{model.py => models.py} (96%) rename privacy/bolton/{model_test.py => models_test.py} (96%) rename privacy/bolton/{optimizer.py => optimizers.py} (99%) rename privacy/bolton/{optimizer_test.py => optimizers_test.py} (99%) diff --git a/privacy/__init__.py b/privacy/__init__.py index 59bfe20..e494c62 100644 --- a/privacy/__init__.py +++ b/privacy/__init__.py @@ -41,3 +41,9 @@ else: from privacy.optimizers.dp_optimizer import DPAdamOptimizer from privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer + + from privacy.bolton.models import BoltonModel + from privacy.bolton.optimizers import Bolton + from privacy.bolton.losses import StrongConvexMixin + from privacy.bolton.losses import StrongConvexBinaryCrossentropy + from privacy.bolton.losses import StrongConvexHuber diff --git 
a/privacy/bolton/__init__.py b/privacy/bolton/__init__.py index 67b6148..971b804 100644 --- a/privacy/bolton/__init__.py +++ b/privacy/bolton/__init__.py @@ -9,6 +9,7 @@ if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): if hasattr(sys, 'skip_tf_privacy_import'): # Useful for standalone scripts. pass else: - from privacy.bolton.model import Bolton - from privacy.bolton.loss import StrongConvexHuber - from privacy.bolton.loss import StrongConvexBinaryCrossentropy \ No newline at end of file + from privacy.bolton.models import BoltonModel + from privacy.bolton.optimizers import Bolton + from privacy.bolton.losses import StrongConvexHuber + from privacy.bolton.losses import StrongConvexBinaryCrossentropy \ No newline at end of file diff --git a/privacy/bolton/loss.py b/privacy/bolton/losses.py similarity index 51% rename from privacy/bolton/loss.py rename to privacy/bolton/losses.py index 4ed0479..a326946 100644 --- a/privacy/bolton/loss.py +++ b/privacy/bolton/losses.py @@ -21,6 +21,7 @@ from tensorflow.python.keras import losses from tensorflow.python.keras.utils import losses_utils from tensorflow.python.framework import ops as _ops from tensorflow.python.keras.regularizers import L1L2 +from tensorflow.python.platform import tf_logging as logging class StrongConvexMixin: @@ -147,7 +148,7 @@ class StrongConvexHuber(losses.Loss, StrongConvexMixin): self.dtype = dtype self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) super(StrongConvexHuber, self).__init__( - name='huber', + name='strongconvexhuber', reduction=reduction, ) @@ -245,6 +246,11 @@ class StrongConvexBinaryCrossentropy( name: Name of the loss instance dtype: tf datatype to use for tensor conversions. """ + if label_smoothing != 0: + logging.warning('The impact of label smoothing on privacy is unknown. 
' + 'Use label smoothing at your own risk as it may not ' + 'guarantee privacy.') + if reg_lambda <= 0: raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) if C <= 0: @@ -258,7 +264,7 @@ class StrongConvexBinaryCrossentropy( self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) super(StrongConvexBinaryCrossentropy, self).__init__( reduction=reduction, - name='binarycrossentropy', + name='strongconvexbinarycrossentropy', from_logits=from_logits, label_smoothing=label_smoothing, ) @@ -313,280 +319,3 @@ class StrongConvexBinaryCrossentropy( return L1L2(l2=self.reg_lambda/2) -# class StrongConvexSparseCategoricalCrossentropy( -# losses.CategoricalCrossentropy, -# StrongConvexMixin -# ): -# """ -# Strong Convex version of CategoricalCrossentropy loss using l2 weight -# regularization. -# """ -# -# def __init__(self, -# reg_lambda: float, -# C: float, -# radius_constant: float, -# from_logits: bool = True, -# label_smoothing: float = 0, -# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, -# name: str = 'binarycrossentropy', -# dtype=tf.float32): -# """ -# Args: -# reg_lambda: Weight regularization constant -# C: Penalty parameter C of the loss term -# radius_constant: constant defining the length of the radius -# reduction: reduction type to use. See super class -# label_smoothing: amount of smoothing to perform on labels -# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x) -# name: Name of the loss instance -# dtype: tf datatype to use for tensor conversions. 
-# """ -# if reg_lambda <= 0: -# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) -# if C <= 0: -# raise ValueError('c: {0}, should be >= 0'.format(C)) -# if radius_constant <= 0: -# raise ValueError('radius_constant: {0}, should be >= 0'.format( -# radius_constant -# )) -# -# self.C = C -# self.dtype = dtype -# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) -# super(StrongConvexSparseCategoricalCrossentropy, self).__init__( -# reduction=reduction, -# name=name, -# from_logits=from_logits, -# label_smoothing=label_smoothing, -# ) -# self.radius_constant = radius_constant -# -# def call(self, y_true, y_pred): -# """Compute loss -# -# Args: -# y_true: Ground truth values. -# y_pred: The predicted values. -# -# Returns: -# Loss values per sample. -# """ -# loss = super() -# loss = loss * self.C -# return loss -# -# def radius(self): -# """See super class. -# """ -# return self.radius_constant / self.reg_lambda -# -# def gamma(self): -# """See super class. -# """ -# return self.reg_lambda -# -# def beta(self, class_weight): -# """See super class. -# """ -# max_class_weight = self.max_class_weight(class_weight, self.dtype) -# return self.C * max_class_weight + self.reg_lambda -# -# def lipchitz_constant(self, class_weight): -# """See super class. -# """ -# max_class_weight = self.max_class_weight(class_weight, self.dtype) -# return self.C * max_class_weight + self.reg_lambda * self.radius() -# -# def kernel_regularizer(self): -# """ -# l2 loss using reg_lambda as the l2 term (as desired). Required for -# this loss function to be strongly convex. -# :return: -# """ -# return L1L2(l2=self.reg_lambda) -# -# class StrongConvexSparseCategoricalCrossentropy( -# losses.SparseCategoricalCrossentropy, -# StrongConvexMixin -# ): -# """ -# Strong Convex version of SparseCategoricalCrossentropy loss using l2 weight -# regularization. 
-# """ -# -# def __init__(self, -# reg_lambda: float, -# C: float, -# radius_constant: float, -# from_logits: bool = True, -# label_smoothing: float = 0, -# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, -# name: str = 'binarycrossentropy', -# dtype=tf.float32): -# """ -# Args: -# reg_lambda: Weight regularization constant -# C: Penalty parameter C of the loss term -# radius_constant: constant defining the length of the radius -# reduction: reduction type to use. See super class -# label_smoothing: amount of smoothing to perform on labels -# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x) -# name: Name of the loss instance -# dtype: tf datatype to use for tensor conversions. -# """ -# if reg_lambda <= 0: -# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) -# if C <= 0: -# raise ValueError('c: {0}, should be >= 0'.format(C)) -# if radius_constant <= 0: -# raise ValueError('radius_constant: {0}, should be >= 0'.format( -# radius_constant -# )) -# -# self.C = C -# self.dtype = dtype -# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) -# super(StrongConvexHuber, self).__init__(reduction=reduction, -# name=name, -# from_logits=from_logits, -# label_smoothing=label_smoothing, -# ) -# self.radius_constant = radius_constant -# -# def call(self, y_true, y_pred): -# """Compute loss -# -# Args: -# y_true: Ground truth values. -# y_pred: The predicted values. -# -# Returns: -# Loss values per sample. -# """ -# loss = super() -# loss = loss * self.C -# return loss -# -# def radius(self): -# """See super class. -# """ -# return self.radius_constant / self.reg_lambda -# -# def gamma(self): -# """See super class. -# """ -# return self.reg_lambda -# -# def beta(self, class_weight): -# """See super class. -# """ -# max_class_weight = self.max_class_weight(class_weight, self.dtype) -# return self.C * max_class_weight + self.reg_lambda -# -# def lipchitz_constant(self, class_weight): -# """See super class. 
-# """ -# max_class_weight = self.max_class_weight(class_weight, self.dtype) -# return self.C * max_class_weight + self.reg_lambda * self.radius() -# -# def kernel_regularizer(self): -# """ -# l2 loss using reg_lambda as the l2 term (as desired). Required for -# this loss function to be strongly convex. -# :return: -# """ -# return L1L2(l2=self.reg_lambda) -# -# -# class StrongConvexCategoricalCrossentropy( -# losses.CategoricalCrossentropy, -# StrongConvexMixin -# ): -# """ -# Strong Convex version of CategoricalCrossentropy loss using l2 weight -# regularization. -# """ -# -# def __init__(self, -# reg_lambda: float, -# C: float, -# radius_constant: float, -# from_logits: bool = True, -# label_smoothing: float = 0, -# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, -# name: str = 'binarycrossentropy', -# dtype=tf.float32): -# """ -# Args: -# reg_lambda: Weight regularization constant -# C: Penalty parameter C of the loss term -# radius_constant: constant defining the length of the radius -# reduction: reduction type to use. See super class -# label_smoothing: amount of smoothing to perform on labels -# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x) -# name: Name of the loss instance -# dtype: tf datatype to use for tensor conversions. -# """ -# if reg_lambda <= 0: -# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) -# if C <= 0: -# raise ValueError('c: {0}, should be >= 0'.format(C)) -# if radius_constant <= 0: -# raise ValueError('radius_constant: {0}, should be >= 0'.format( -# radius_constant -# )) -# -# self.C = C -# self.dtype = dtype -# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) -# super(StrongConvexHuber, self).__init__(reduction=reduction, -# name=name, -# from_logits=from_logits, -# label_smoothing=label_smoothing, -# ) -# self.radius_constant = radius_constant -# -# def call(self, y_true, y_pred): -# """Compute loss -# -# Args: -# y_true: Ground truth values. 
-# y_pred: The predicted values. -# -# Returns: -# Loss values per sample. -# """ -# loss = super() -# loss = loss * self.C -# return loss -# -# def radius(self): -# """See super class. -# """ -# return self.radius_constant / self.reg_lambda -# -# def gamma(self): -# """See super class. -# """ -# return self.reg_lambda -# -# def beta(self, class_weight): -# """See super class. -# """ -# max_class_weight = self.max_class_weight(class_weight, self.dtype) -# return self.C * max_class_weight + self.reg_lambda -# -# def lipchitz_constant(self, class_weight): -# """See super class. -# """ -# max_class_weight = self.max_class_weight(class_weight, self.dtype) -# return self.C * max_class_weight + self.reg_lambda * self.radius() -# -# def kernel_regularizer(self): -# """ -# l2 loss using reg_lambda as the l2 term (as desired). Required for -# this loss function to be strongly convex. -# :return: -# """ -# return L1L2(l2=self.reg_lambda) diff --git a/privacy/bolton/loss_test.py b/privacy/bolton/losses_test.py similarity index 98% rename from privacy/bolton/loss_test.py rename to privacy/bolton/losses_test.py index 488710f..d2c9f80 100644 --- a/privacy/bolton/loss_test.py +++ b/privacy/bolton/losses_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Unit testing for loss.py""" +"""Unit testing for losses.py""" from __future__ import absolute_import from __future__ import division @@ -22,9 +22,9 @@ from tensorflow.python.keras import keras_parameterized from tensorflow.python.framework import test_util from tensorflow.python.keras.regularizers import L1L2 from absl.testing import parameterized -from privacy.bolton.loss import StrongConvexBinaryCrossentropy -from privacy.bolton.loss import StrongConvexHuber -from privacy.bolton.loss import StrongConvexMixin +from privacy.bolton.losses import StrongConvexBinaryCrossentropy +from privacy.bolton.losses import StrongConvexHuber +from privacy.bolton.losses import StrongConvexMixin class StrongConvexMixinTests(keras_parameterized.TestCase): @@ -355,7 +355,7 @@ class HuberTests(keras_parameterized.TestCase): 'fn': 'kernel_regularizer', 'init_args': [1, 1, 1, 1], 'args': [], - 'result': L1L2(l2=1), + 'result': L1L2(l2=0.5), }, ]) def test_fns(self, init_args, fn, args, result): diff --git a/privacy/bolton/model.py b/privacy/bolton/models.py similarity index 96% rename from privacy/bolton/model.py rename to privacy/bolton/models.py index 6f3f48e..0a2efc0 100644 --- a/privacy/bolton/model.py +++ b/privacy/bolton/models.py @@ -20,8 +20,8 @@ import tensorflow as tf from tensorflow.python.keras.models import Model from tensorflow.python.keras import optimizers from tensorflow.python.framework import ops as _ops -from privacy.bolton.loss import StrongConvexMixin -from privacy.bolton.optimizer import Bolton +from privacy.bolton.losses import StrongConvexMixin +from privacy.bolton.optimizers import Bolton class BoltonModel(Model): @@ -142,7 +142,9 @@ class BoltonModel(Model): """ if class_weight is None: - class_weight = self.calculate_class_weights(class_weight) + class_weight_ = self.calculate_class_weights(class_weight) + else: + class_weight_ = class_weight if n_samples is not None: data_size = n_samples elif hasattr(x, 'shape'): @@ -160,10 +162,13 @@ class 
BoltonModel(Model): if batch_size_ is None: raise ValueError('batch_size: {0} is an ' 'invalid value'.format(batch_size_)) + if data_size is None: + raise ValueError('Could not infer the number of samples. Please pass ' + 'this in using n_samples.') with self.optimizer(noise_distribution, epsilon, self.layers, - class_weight, + class_weight_, data_size, self.n_outputs, batch_size_, diff --git a/privacy/bolton/model_test.py b/privacy/bolton/models_test.py similarity index 96% rename from privacy/bolton/model_test.py rename to privacy/bolton/models_test.py index 4316a1e..05119d3 100644 --- a/privacy/bolton/model_test.py +++ b/privacy/bolton/models_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Unit testing for model.py""" +"""Unit testing for models.py""" from __future__ import absolute_import from __future__ import division @@ -25,9 +25,9 @@ from tensorflow.python.keras import losses from tensorflow.python.framework import ops as _ops from tensorflow.python.keras.regularizers import L1L2 from absl.testing import parameterized -from privacy.bolton import model -from privacy.bolton.optimizer import Bolton -from privacy.bolton.loss import StrongConvexMixin +from privacy.bolton import models +from privacy.bolton.optimizers import Bolton +from privacy.bolton.losses import StrongConvexMixin class TestLoss(losses.Loss, StrongConvexMixin): """Test loss function for testing Bolton model""" @@ -130,8 +130,8 @@ class InitTests(keras_parameterized.TestCase): n_outputs: number of output neurons """ # test valid domains for each variable - clf = model.BoltonModel(n_outputs) - self.assertIsInstance(clf, model.BoltonModel) + clf = models.BoltonModel(n_outputs) + self.assertIsInstance(clf, models.BoltonModel) @parameterized.named_parameters([ {'testcase_name': 'invalid n_outputs', @@ -146,7 +146,7 @@ class 
InitTests(keras_parameterized.TestCase): """ # test invalid domains for each variable, especially noise with self.assertRaises(ValueError): - model.BoltonModel(n_outputs) + models.BoltonModel(n_outputs) @parameterized.named_parameters([ {'testcase_name': 'string compile', @@ -170,7 +170,7 @@ class InitTests(keras_parameterized.TestCase): """ # test compilation of valid tf.optimizer and tf.loss with self.cached_session(): - clf = model.BoltonModel(n_outputs) + clf = models.BoltonModel(n_outputs) clf.compile(optimizer, loss) self.assertEqual(clf.loss, loss) @@ -197,7 +197,7 @@ class InitTests(keras_parameterized.TestCase): # test compilaton of invalid tf.optimizer and non instantiated loss. with self.cached_session(): with self.assertRaises((ValueError, AttributeError)): - clf = model.BoltonModel(n_outputs) + clf = models.BoltonModel(n_outputs) clf.compile(optimizer, loss) @@ -261,7 +261,7 @@ def _do_fit(n_samples, Returns: BoltonModel instsance """ - clf = model.BoltonModel(n_outputs) + clf = models.BoltonModel(n_outputs) clf.compile(optimizer, loss) if generator: x = _cat_dataset( @@ -355,7 +355,7 @@ class FitTests(keras_parameterized.TestCase): input_dim = 5 batch_size = 1 n_samples = 10 - clf = model.BoltonModel(n_classes) + clf = models.BoltonModel(n_classes) clf.compile(optimizer, loss) x = _cat_dataset( n_samples, @@ -441,7 +441,7 @@ class FitTests(keras_parameterized.TestCase): num_classes: number of outputs neurons result: expected result """ - clf = model.BoltonModel(1, 1) + clf = models.BoltonModel(1, 1) expected = clf.calculate_class_weights(class_weights, class_counts, num_classes @@ -508,7 +508,7 @@ class FitTests(keras_parameterized.TestCase): num_classes: number of outputs neurons result: expected result """ - clf = model.BoltonModel(1, 1) + clf = models.BoltonModel(1, 1) with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method clf.calculate_class_weights(class_weights, class_counts, diff --git 
a/privacy/bolton/optimizer.py b/privacy/bolton/optimizers.py similarity index 99% rename from privacy/bolton/optimizer.py rename to privacy/bolton/optimizers.py index cfd0b98..726ec4f 100644 --- a/privacy/bolton/optimizer.py +++ b/privacy/bolton/optimizers.py @@ -21,7 +21,7 @@ import tensorflow as tf from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.ops import math_ops from tensorflow.python import ops as _ops -from privacy.bolton.loss import StrongConvexMixin +from privacy.bolton.losses import StrongConvexMixin _accepted_distributions = ['laplace'] # implemented distributions for noising diff --git a/privacy/bolton/optimizer_test.py b/privacy/bolton/optimizers_test.py similarity index 99% rename from privacy/bolton/optimizer_test.py rename to privacy/bolton/optimizers_test.py index 2060031..0a9f9cc 100644 --- a/privacy/bolton/optimizer_test.py +++ b/privacy/bolton/optimizers_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Unit testing for optimizer.py""" +"""Unit testing for optimizers.py""" from __future__ import absolute_import from __future__ import division @@ -29,8 +29,8 @@ from tensorflow.python.framework import ops as _ops from tensorflow.python.framework import test_util from tensorflow.python import ops as _ops from absl.testing import parameterized -from privacy.bolton.loss import StrongConvexMixin -from privacy.bolton import optimizer as opt +from privacy.bolton.losses import StrongConvexMixin +from privacy.bolton import optimizers as opt From 56e16f0a15a531b0e9435cfd6c9d32f1d5be1d69 Mon Sep 17 00:00:00 2001 From: Christopher Choquette Choo Date: Wed, 19 Jun 2019 11:04:18 -0400 Subject: [PATCH 7/9] Minor changes + tutorial --- privacy/bolton/losses.py | 278 +++++++++++++++++++ privacy/bolton/models.py | 2 - privacy/bolton/optimizers.py | 24 +- privacy/bolton/optimizers_test.py | 4 +- tutorials/bolton_tutorial.ipynb | 432 ++++++++++++++++++++++++++++++ 5 files changed, 719 insertions(+), 21 deletions(-) create mode 100644 tutorials/bolton_tutorial.ipynb diff --git a/privacy/bolton/losses.py b/privacy/bolton/losses.py index a326946..6a54576 100644 --- a/privacy/bolton/losses.py +++ b/privacy/bolton/losses.py @@ -319,3 +319,281 @@ class StrongConvexBinaryCrossentropy( return L1L2(l2=self.reg_lambda/2) +# class StrongConvexSparseCategoricalCrossentropy( +# losses.CategoricalCrossentropy, +# StrongConvexMixin +# ): +# """ +# Strong Convex version of CategoricalCrossentropy loss using l2 weight +# regularization. 
+# """ +# +# def __init__(self, +# reg_lambda: float, +# C: float, +# radius_constant: float, +# from_logits: bool = True, +# label_smoothing: float = 0, +# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, +# name: str = 'binarycrossentropy', +# dtype=tf.float32): +# """ +# Args: +# reg_lambda: Weight regularization constant +# C: Penalty parameter C of the loss term +# radius_constant: constant defining the length of the radius +# reduction: reduction type to use. See super class +# label_smoothing: amount of smoothing to perform on labels +# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x) +# name: Name of the loss instance +# dtype: tf datatype to use for tensor conversions. +# """ +# if reg_lambda <= 0: +# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) +# if C <= 0: +# raise ValueError('c: {0}, should be >= 0'.format(C)) +# if radius_constant <= 0: +# raise ValueError('radius_constant: {0}, should be >= 0'.format( +# radius_constant +# )) +# +# self.C = C +# self.dtype = dtype +# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) +# super(StrongConvexSparseCategoricalCrossentropy, self).__init__( +# reduction=reduction, +# name=name, +# from_logits=from_logits, +# label_smoothing=label_smoothing, +# ) +# self.radius_constant = radius_constant +# +# def call(self, y_true, y_pred): +# """Compute loss +# +# Args: +# y_true: Ground truth values. +# y_pred: The predicted values. +# +# Returns: +# Loss values per sample. +# """ +# loss = super() +# loss = loss * self.C +# return loss +# +# def radius(self): +# """See super class. +# """ +# return self.radius_constant / self.reg_lambda +# +# def gamma(self): +# """See super class. +# """ +# return self.reg_lambda +# +# def beta(self, class_weight): +# """See super class. 
+# """ +# max_class_weight = self.max_class_weight(class_weight, self.dtype) +# return self.C * max_class_weight + self.reg_lambda +# +# def lipchitz_constant(self, class_weight): +# """See super class. +# """ +# max_class_weight = self.max_class_weight(class_weight, self.dtype) +# return self.C * max_class_weight + self.reg_lambda * self.radius() +# +# def kernel_regularizer(self): +# """ +# l2 loss using reg_lambda as the l2 term (as desired). Required for +# this loss function to be strongly convex. +# :return: +# """ +# return L1L2(l2=self.reg_lambda) +# +# class StrongConvexSparseCategoricalCrossentropy( +# losses.SparseCategoricalCrossentropy, +# StrongConvexMixin +# ): +# """ +# Strong Convex version of SparseCategoricalCrossentropy loss using l2 weight +# regularization. +# """ +# +# def __init__(self, +# reg_lambda: float, +# C: float, +# radius_constant: float, +# from_logits: bool = True, +# label_smoothing: float = 0, +# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, +# name: str = 'binarycrossentropy', +# dtype=tf.float32): +# """ +# Args: +# reg_lambda: Weight regularization constant +# C: Penalty parameter C of the loss term +# radius_constant: constant defining the length of the radius +# reduction: reduction type to use. See super class +# label_smoothing: amount of smoothing to perform on labels +# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x) +# name: Name of the loss instance +# dtype: tf datatype to use for tensor conversions. 
+# """ +# if reg_lambda <= 0: +# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) +# if C <= 0: +# raise ValueError('c: {0}, should be >= 0'.format(C)) +# if radius_constant <= 0: +# raise ValueError('radius_constant: {0}, should be >= 0'.format( +# radius_constant +# )) +# +# self.C = C +# self.dtype = dtype +# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) +# super(StrongConvexHuber, self).__init__(reduction=reduction, +# name=name, +# from_logits=from_logits, +# label_smoothing=label_smoothing, +# ) +# self.radius_constant = radius_constant +# +# def call(self, y_true, y_pred): +# """Compute loss +# +# Args: +# y_true: Ground truth values. +# y_pred: The predicted values. +# +# Returns: +# Loss values per sample. +# """ +# loss = super() +# loss = loss * self.C +# return loss +# +# def radius(self): +# """See super class. +# """ +# return self.radius_constant / self.reg_lambda +# +# def gamma(self): +# """See super class. +# """ +# return self.reg_lambda +# +# def beta(self, class_weight): +# """See super class. +# """ +# max_class_weight = self.max_class_weight(class_weight, self.dtype) +# return self.C * max_class_weight + self.reg_lambda +# +# def lipchitz_constant(self, class_weight): +# """See super class. +# """ +# max_class_weight = self.max_class_weight(class_weight, self.dtype) +# return self.C * max_class_weight + self.reg_lambda * self.radius() +# +# def kernel_regularizer(self): +# """ +# l2 loss using reg_lambda as the l2 term (as desired). Required for +# this loss function to be strongly convex. +# :return: +# """ +# return L1L2(l2=self.reg_lambda) +# +# +# class StrongConvexCategoricalCrossentropy( +# losses.CategoricalCrossentropy, +# StrongConvexMixin +# ): +# """ +# Strong Convex version of CategoricalCrossentropy loss using l2 weight +# regularization. 
+# """ +# +# def __init__(self, +# reg_lambda: float, +# C: float, +# radius_constant: float, +# from_logits: bool = True, +# label_smoothing: float = 0, +# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, +# name: str = 'binarycrossentropy', +# dtype=tf.float32): +# """ +# Args: +# reg_lambda: Weight regularization constant +# C: Penalty parameter C of the loss term +# radius_constant: constant defining the length of the radius +# reduction: reduction type to use. See super class +# label_smoothing: amount of smoothing to perform on labels +# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x) +# name: Name of the loss instance +# dtype: tf datatype to use for tensor conversions. +# """ +# if reg_lambda <= 0: +# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) +# if C <= 0: +# raise ValueError('c: {0}, should be >= 0'.format(C)) +# if radius_constant <= 0: +# raise ValueError('radius_constant: {0}, should be >= 0'.format( +# radius_constant +# )) +# +# self.C = C +# self.dtype = dtype +# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) +# super(StrongConvexHuber, self).__init__(reduction=reduction, +# name=name, +# from_logits=from_logits, +# label_smoothing=label_smoothing, +# ) +# self.radius_constant = radius_constant +# +# def call(self, y_true, y_pred): +# """Compute loss +# +# Args: +# y_true: Ground truth values. +# y_pred: The predicted values. +# +# Returns: +# Loss values per sample. +# """ +# loss = super() +# loss = loss * self.C +# return loss +# +# def radius(self): +# """See super class. +# """ +# return self.radius_constant / self.reg_lambda +# +# def gamma(self): +# """See super class. +# """ +# return self.reg_lambda +# +# def beta(self, class_weight): +# """See super class. +# """ +# max_class_weight = self.max_class_weight(class_weight, self.dtype) +# return self.C * max_class_weight + self.reg_lambda +# +# def lipchitz_constant(self, class_weight): +# """See super class. 
+# """ +# max_class_weight = self.max_class_weight(class_weight, self.dtype) +# return self.C * max_class_weight + self.reg_lambda * self.radius() +# +# def kernel_regularizer(self): +# """ +# l2 loss using reg_lambda as the l2 term (as desired). Required for +# this loss function to be strongly convex. +# :return: +# """ +# return L1L2(l2=self.reg_lambda) + diff --git a/privacy/bolton/models.py b/privacy/bolton/models.py index 0a2efc0..06d1c4b 100644 --- a/privacy/bolton/models.py +++ b/privacy/bolton/models.py @@ -170,7 +170,6 @@ class BoltonModel(Model): self.layers, class_weight_, data_size, - self.n_outputs, batch_size_, ) as _: out = super(BoltonModel, self).fit(x=x, @@ -223,7 +222,6 @@ class BoltonModel(Model): self.layers, class_weight, data_size, - self.n_outputs, batch_size ) as _: out = super(BoltonModel, self).fit_generator( diff --git a/privacy/bolton/optimizers.py b/privacy/bolton/optimizers.py index 726ec4f..28c1735 100644 --- a/privacy/bolton/optimizers.py +++ b/privacy/bolton/optimizers.py @@ -137,7 +137,6 @@ class Bolton(optimizer_v2.OptimizerV2): 'class_weights', 'input_dim', 'n_samples', - 'n_outputs', 'layers', 'batch_size', '_is_init' @@ -166,6 +165,9 @@ class Bolton(optimizer_v2.OptimizerV2): Returns: """ + if not self._is_init: + raise Exception('This method must be called from within the optimizer\'s ' + 'context.') radius = self.loss.radius() for layer in self.layers: weight_norm = tf.norm(layer.kernel, axis=0) @@ -323,7 +325,6 @@ class Bolton(optimizer_v2.OptimizerV2): layers: list, class_weights, n_samples, - n_outputs, batch_size ): """Entry point from context. Accepts required values for bolton method and @@ -338,7 +339,6 @@ class Bolton(optimizer_v2.OptimizerV2): class_weights: class_weights used, which may either be a scalar or 1D tensor with dim == n_classes. n_samples number of rows/individual samples in the training set - n_outputs: number of output classes batch_size: batch size used. 
""" if epsilon <= 0: @@ -352,20 +352,11 @@ class Bolton(optimizer_v2.OptimizerV2): self.learning_rate.initialize(self.loss.beta(class_weights), self.loss.gamma() ) - self.epsilon = _ops.convert_to_tensor_v2(epsilon, dtype=self.dtype) - self.class_weights = _ops.convert_to_tensor_v2(class_weights, - dtype=self.dtype - ) - self.n_samples = _ops.convert_to_tensor_v2(n_samples, - dtype=self.dtype - ) - self.n_outputs = _ops.convert_to_tensor_v2(n_outputs, - dtype=self.dtype - ) + self.epsilon = tf.constant(epsilon, dtype=self.dtype) + self.class_weights = tf.constant(class_weights, dtype=self.dtype) + self.n_samples = tf.constant(n_samples, dtype=self.dtype) self.layers = layers - self.batch_size = _ops.convert_to_tensor_v2(batch_size, - dtype=self.dtype - ) + self.batch_size = tf.constant(batch_size, dtype=self.dtype) return self def __exit__(self, *args): @@ -397,6 +388,5 @@ class Bolton(optimizer_v2.OptimizerV2): self.class_weights = None self.n_samples = None self.input_dim = None - self.n_outputs = None self.layers = None self._is_init = False diff --git a/privacy/bolton/optimizers_test.py b/privacy/bolton/optimizers_test.py index 0a9f9cc..1d0fbfb 100644 --- a/privacy/bolton/optimizers_test.py +++ b/privacy/bolton/optimizers_test.py @@ -314,7 +314,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase): model.layers[0].kernel = \ model.layers[0].kernel_initializer((model.layer_input_shape[0], model.n_outputs)) - with bolton(noise, epsilon, model.layers, class_weights, 1, 1, 1) as _: + with bolton(noise, epsilon, model.layers, class_weights, 1, 1) as _: pass return _ops.convert_to_tensor_v2(bolton.epsilon, dtype=tf.float32) epsilon = test_run() @@ -349,7 +349,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase): model.layers[0].kernel = \ model.layers[0].kernel_initializer((model.layer_input_shape[0], model.n_outputs)) - with bolton(noise, epsilon, model.layers, 1, 1, 1, 1) as _: + with bolton(noise, epsilon, model.layers, 1, 1, 1) as _: pass with 
self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method test_run(noise, epsilon) diff --git a/tutorials/bolton_tutorial.ipynb b/tutorials/bolton_tutorial.ipynb new file mode 100644 index 0000000..b60e612 --- /dev/null +++ b/tutorials/bolton_tutorial.ipynb @@ -0,0 +1,432 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "pycharm": { + "is_executing": false + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append('..')\n", + "import tensorflow as tf\n", + "from privacy.bolton import losses\n", + "from privacy.bolton import models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we will create a binary classification dataset with a single output dimension.\n", + "The samples for each label are repeated datapoints at different points in space." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "pycharm": { + "is_executing": false, + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(20, 2) (20, 1)\n" + ] + } + ], + "source": [ + "# Parameters for dataset\n", + "n_samples = 10\n", + "input_dim = 2\n", + "n_outputs = 1\n", + "# Create binary classification dataset:\n", + "x_stack = [tf.constant(-1, tf.float32, (n_samples, input_dim)), \n", + " tf.constant(1, tf.float32, (n_samples, input_dim))]\n", + "y_stack = [tf.constant(0, tf.float32, (n_samples, 1)),\n", + " tf.constant(1, tf.float32, (n_samples, 1))]\n", + "x, y = tf.concat(x_stack, 0), tf.concat(y_stack, 0)\n", + "print(x.shape, y.shape)\n", + "generator = tf.data.Dataset.from_tensor_slices((x, y))\n", + "generator = generator.batch(10)\n", + "generator = generator.shuffle(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we will explore using the pre-built BoltonModel, which is a thin wrapper around a Keras Model using a single-layer neural network. 
It automatically uses the Bolton Optimizer which encompasses all the logic required for the Bolton Differential Privacy method.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "bolt = models.BoltonModel(n_outputs) # tell the model how many outputs we have." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we will pick our optimizer and Strongly Convex Loss function. The loss must extend from StrongConvexMixin and implement the associated methods. Some existing loss functions are pre-implemented in bolton.losses." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "optimizer = tf.optimizers.SGD()\n", + "reg_lambda = 1\n", + "C = 1\n", + "radius_constant = 1\n", + "loss = losses.StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For simplicity, we pick all parameters of the StrongConvexBinaryCrossentropy to be 1; these are all tunable and their impact can be read in losses.StrongConvexBinaryCrossentropy. We then compile the model with the chosen optimizer and loss, which will automatically wrap the chosen optimizer with the Bolton Optimizer, ensuring the required components function as required for privacy guarantees." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "bolt.compile(optimizer, loss)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To fit the model, the optimizer will require additional information about the dataset and model. These parameters are:\n", + "1. the class_weights used\n", + "2. the number of samples in the dataset\n", + "3.
the batch size\n", + "which the model will try to infer, if possible. If not, you will be required to pass these explicitly to the fit method.\n", + "As well, there are two privacy parameters that can be altered: \n", + "1. epsilon, a float\n", + "2. noise_distribution, a valid string indicating the distribution to use (must be implemented)\n", + "\n", + "The BoltonModel offers a helper method, .calculate_class_weights to aid in class_weight calculation." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: Logging before flag parsing goes to stderr.\n", + "W0619 11:00:32.392859 4467058112 deprecation.py:323] From /Users/christopherchoo/PycharmProjects/privacy/venv/lib/python3.6/site-packages/tensorflow/python/ops/nn_impl.py:182: add_dispatch_support..wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "Use tf.where in 2.0, which has the same broadcast rule as np.where\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 20 samples\n", + "Epoch 1/2\n", + "20/20 [==============================] - 0s 4ms/sample - loss: 0.8146\n", + "Epoch 2/2\n", + "20/20 [==============================] - 0s 94us/sample - loss: 0.5699\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# required parameters\n", + "class_weight = None # default, use .calculate_class_weights to specify other values\n", + "batch_size = None # default, if it cannot be inferred, specify this\n", + "n_samples = None # default, if it cannot be inferred, specify this\n", + "# privacy parameters\n", + "epsilon = 2\n", + "noise_distribution = 'laplace'\n", + "\n", + "bolt.fit(x, \n", + " y, \n", + " epsilon=epsilon, \n", + " class_weight=class_weight, \n", + "
batch_size=batch_size, \n", + " n_samples=n_samples,\n", + " noise_distribution=noise_distribution,\n", + " epochs=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We may also train on a generator object, or try different optimizers and loss functions. Below, we will see that we must pass the number of samples as the fit method is unable to infer it for a generator." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "optimizer2 = tf.optimizers.Adam()\n", + "bolt.compile(optimizer2, loss)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Could not infer the number of samples. Please pass this in using n_samples.\n" + ] + } + ], + "source": [ + "# required parameters\n", + "class_weight = None # default, use .calculate_class_weights to specify other values\n", + "batch_size = None # default, if it cannot be inferred, specify this\n", + "n_samples = None # default, if it cannot be inferred, specify this\n", + "# privacy parameters\n", + "epsilon = 2\n", + "noise_distribution = 'laplace'\n", + "try:\n", + " bolt.fit(generator,\n", + " epsilon=epsilon, \n", + " class_weight=class_weight, \n", + " batch_size=batch_size, \n", + " n_samples=n_samples,\n", + " noise_distribution=noise_distribution,\n", + " verbose=0\n", + " )\n", + "except ValueError as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And now, re-running with the parameter set."
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "n_samples = 20\n", + "bolt.fit(generator,\n", + " epsilon=epsilon, \n", + " class_weight=class_weight, \n", + " batch_size=batch_size, \n", + " n_samples=n_samples,\n", + " noise_distribution=noise_distribution,\n", + " verbose=0\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You don't have to use the bolton model to use the Bolton method. There are only a few requirements:\n", + "1. make sure any requirements from the loss are implemented in the model.\n", + "2. instantiate the optimizer and use it as a context around your fit operation." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from privacy.bolton.optimizers import Bolton" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here, we create our own model and setup the Bolton optimizer." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "class TestModel(tf.keras.Model):\n", + " def __init__(self, reg_layer, n_outputs=1):\n", + " super(TestModel, self).__init__(name='test')\n", + " self.output_layer = tf.keras.layers.Dense(n_outputs,\n", + " kernel_regularizer=reg_layer\n", + " )\n", + " \n", + " def call(self, inputs, training=False):\n", + " return self.output_layer(inputs)\n", + "\n", + "optimizer = tf.optimizers.SGD()\n", + "loss = losses.StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)\n", + "optimizer = Bolton(optimizer, loss)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we instantiate our model and check for 1. Since our loss requires L2 regularization over the kernel, we will pass it to the model." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "n_outputs = 1 # parameter for model and optimizer context.\n", + "test_model = TestModel(loss.kernel_regularizer(), n_outputs)\n", + "test_model.compile(optimizer, loss)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We comply with 2., and use the Bolton Optimizer as a context around the fit method." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 20 samples\n", + "Epoch 1/2\n", + "20/20 [==============================] - 0s 3ms/sample - loss: 0.9096\n", + "Epoch 2/2\n", + "20/20 [==============================] - 0s 430us/sample - loss: 0.5275\n" + ] + } + ], + "source": [ + "# parameters for context\n", + "noise_distribution = 'laplace'\n", + "epsilon = 2\n", + "class_weights = 1 # Previously, the fit method auto-detected the class_weights.\n", + "# Here, we need to pass the class_weights explicitly.
1 is the equivalent of None.\n", + "n_samples = 20\n", + "batch_size = 5\n", + "\n", + "with optimizer(\n", + " noise_distribution=noise_distribution,\n", + " epsilon=epsilon,\n", + " layers=test_model.layers,\n", + " class_weights=class_weights, \n", + " n_samples=n_samples,\n", + " batch_size=batch_size\n", + ") as _:\n", + " test_model.fit(x, y, batch_size=batch_size, epochs=2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} From b120d9c5d84fe03d88ca78ce297d748f6ef38cf7 Mon Sep 17 00:00:00 2001 From: Christopher Choquette Choo Date: Wed, 19 Jun 2019 11:14:02 -0400 Subject: [PATCH 8/9] Changes for pylint. --- privacy/bolton/losses.py | 280 ------------------------------ privacy/bolton/models_test.py | 4 +- privacy/bolton/optimizers.py | 1 - privacy/bolton/optimizers_test.py | 26 +-- 4 files changed, 8 insertions(+), 303 deletions(-) diff --git a/privacy/bolton/losses.py b/privacy/bolton/losses.py index 6a54576..a99187b 100644 --- a/privacy/bolton/losses.py +++ b/privacy/bolton/losses.py @@ -317,283 +317,3 @@ class StrongConvexBinaryCrossentropy( :return: """ return L1L2(l2=self.reg_lambda/2) - - -# class StrongConvexSparseCategoricalCrossentropy( -# losses.CategoricalCrossentropy, -# StrongConvexMixin -# ): -# """ -# Strong Convex version of CategoricalCrossentropy loss using l2 weight -# regularization. 
-# """ -# -# def __init__(self, -# reg_lambda: float, -# C: float, -# radius_constant: float, -# from_logits: bool = True, -# label_smoothing: float = 0, -# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, -# name: str = 'binarycrossentropy', -# dtype=tf.float32): -# """ -# Args: -# reg_lambda: Weight regularization constant -# C: Penalty parameter C of the loss term -# radius_constant: constant defining the length of the radius -# reduction: reduction type to use. See super class -# label_smoothing: amount of smoothing to perform on labels -# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x) -# name: Name of the loss instance -# dtype: tf datatype to use for tensor conversions. -# """ -# if reg_lambda <= 0: -# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) -# if C <= 0: -# raise ValueError('c: {0}, should be >= 0'.format(C)) -# if radius_constant <= 0: -# raise ValueError('radius_constant: {0}, should be >= 0'.format( -# radius_constant -# )) -# -# self.C = C -# self.dtype = dtype -# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) -# super(StrongConvexSparseCategoricalCrossentropy, self).__init__( -# reduction=reduction, -# name=name, -# from_logits=from_logits, -# label_smoothing=label_smoothing, -# ) -# self.radius_constant = radius_constant -# -# def call(self, y_true, y_pred): -# """Compute loss -# -# Args: -# y_true: Ground truth values. -# y_pred: The predicted values. -# -# Returns: -# Loss values per sample. -# """ -# loss = super() -# loss = loss * self.C -# return loss -# -# def radius(self): -# """See super class. -# """ -# return self.radius_constant / self.reg_lambda -# -# def gamma(self): -# """See super class. -# """ -# return self.reg_lambda -# -# def beta(self, class_weight): -# """See super class. 
-# """ -# max_class_weight = self.max_class_weight(class_weight, self.dtype) -# return self.C * max_class_weight + self.reg_lambda -# -# def lipchitz_constant(self, class_weight): -# """See super class. -# """ -# max_class_weight = self.max_class_weight(class_weight, self.dtype) -# return self.C * max_class_weight + self.reg_lambda * self.radius() -# -# def kernel_regularizer(self): -# """ -# l2 loss using reg_lambda as the l2 term (as desired). Required for -# this loss function to be strongly convex. -# :return: -# """ -# return L1L2(l2=self.reg_lambda) -# -# class StrongConvexSparseCategoricalCrossentropy( -# losses.SparseCategoricalCrossentropy, -# StrongConvexMixin -# ): -# """ -# Strong Convex version of SparseCategoricalCrossentropy loss using l2 weight -# regularization. -# """ -# -# def __init__(self, -# reg_lambda: float, -# C: float, -# radius_constant: float, -# from_logits: bool = True, -# label_smoothing: float = 0, -# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, -# name: str = 'binarycrossentropy', -# dtype=tf.float32): -# """ -# Args: -# reg_lambda: Weight regularization constant -# C: Penalty parameter C of the loss term -# radius_constant: constant defining the length of the radius -# reduction: reduction type to use. See super class -# label_smoothing: amount of smoothing to perform on labels -# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x) -# name: Name of the loss instance -# dtype: tf datatype to use for tensor conversions. 
-# """ -# if reg_lambda <= 0: -# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) -# if C <= 0: -# raise ValueError('c: {0}, should be >= 0'.format(C)) -# if radius_constant <= 0: -# raise ValueError('radius_constant: {0}, should be >= 0'.format( -# radius_constant -# )) -# -# self.C = C -# self.dtype = dtype -# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) -# super(StrongConvexHuber, self).__init__(reduction=reduction, -# name=name, -# from_logits=from_logits, -# label_smoothing=label_smoothing, -# ) -# self.radius_constant = radius_constant -# -# def call(self, y_true, y_pred): -# """Compute loss -# -# Args: -# y_true: Ground truth values. -# y_pred: The predicted values. -# -# Returns: -# Loss values per sample. -# """ -# loss = super() -# loss = loss * self.C -# return loss -# -# def radius(self): -# """See super class. -# """ -# return self.radius_constant / self.reg_lambda -# -# def gamma(self): -# """See super class. -# """ -# return self.reg_lambda -# -# def beta(self, class_weight): -# """See super class. -# """ -# max_class_weight = self.max_class_weight(class_weight, self.dtype) -# return self.C * max_class_weight + self.reg_lambda -# -# def lipchitz_constant(self, class_weight): -# """See super class. -# """ -# max_class_weight = self.max_class_weight(class_weight, self.dtype) -# return self.C * max_class_weight + self.reg_lambda * self.radius() -# -# def kernel_regularizer(self): -# """ -# l2 loss using reg_lambda as the l2 term (as desired). Required for -# this loss function to be strongly convex. -# :return: -# """ -# return L1L2(l2=self.reg_lambda) -# -# -# class StrongConvexCategoricalCrossentropy( -# losses.CategoricalCrossentropy, -# StrongConvexMixin -# ): -# """ -# Strong Convex version of CategoricalCrossentropy loss using l2 weight -# regularization. 
-# """ -# -# def __init__(self, -# reg_lambda: float, -# C: float, -# radius_constant: float, -# from_logits: bool = True, -# label_smoothing: float = 0, -# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, -# name: str = 'binarycrossentropy', -# dtype=tf.float32): -# """ -# Args: -# reg_lambda: Weight regularization constant -# C: Penalty parameter C of the loss term -# radius_constant: constant defining the length of the radius -# reduction: reduction type to use. See super class -# label_smoothing: amount of smoothing to perform on labels -# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x) -# name: Name of the loss instance -# dtype: tf datatype to use for tensor conversions. -# """ -# if reg_lambda <= 0: -# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) -# if C <= 0: -# raise ValueError('c: {0}, should be >= 0'.format(C)) -# if radius_constant <= 0: -# raise ValueError('radius_constant: {0}, should be >= 0'.format( -# radius_constant -# )) -# -# self.C = C -# self.dtype = dtype -# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) -# super(StrongConvexHuber, self).__init__(reduction=reduction, -# name=name, -# from_logits=from_logits, -# label_smoothing=label_smoothing, -# ) -# self.radius_constant = radius_constant -# -# def call(self, y_true, y_pred): -# """Compute loss -# -# Args: -# y_true: Ground truth values. -# y_pred: The predicted values. -# -# Returns: -# Loss values per sample. -# """ -# loss = super() -# loss = loss * self.C -# return loss -# -# def radius(self): -# """See super class. -# """ -# return self.radius_constant / self.reg_lambda -# -# def gamma(self): -# """See super class. -# """ -# return self.reg_lambda -# -# def beta(self, class_weight): -# """See super class. -# """ -# max_class_weight = self.max_class_weight(class_weight, self.dtype) -# return self.C * max_class_weight + self.reg_lambda -# -# def lipchitz_constant(self, class_weight): -# """See super class. 
-# """ -# max_class_weight = self.max_class_weight(class_weight, self.dtype) -# return self.C * max_class_weight + self.reg_lambda * self.radius() -# -# def kernel_regularizer(self): -# """ -# l2 loss using reg_lambda as the l2 term (as desired). Required for -# this loss function to be strongly convex. -# :return: -# """ -# return L1L2(l2=self.reg_lambda) - diff --git a/privacy/bolton/models_test.py b/privacy/bolton/models_test.py index 05119d3..63954cc 100644 --- a/privacy/bolton/models_test.py +++ b/privacy/bolton/models_test.py @@ -53,7 +53,7 @@ class TestLoss(losses.Loss, StrongConvexMixin): """ return _ops.convert_to_tensor_v2(1, dtype=tf.float32) - def beta(self, class_weight): + def beta(self, class_weight): # pylint: disable=unused-argument """Beta smoothess Args: @@ -64,7 +64,7 @@ class TestLoss(losses.Loss, StrongConvexMixin): """ return _ops.convert_to_tensor_v2(1, dtype=tf.float32) - def lipchitz_constant(self, class_weight): + def lipchitz_constant(self, class_weight): # pylint: disable=unused-argument """ L lipchitz continuous Args: diff --git a/privacy/bolton/optimizers.py b/privacy/bolton/optimizers.py index 28c1735..ec7a7e5 100644 --- a/privacy/bolton/optimizers.py +++ b/privacy/bolton/optimizers.py @@ -20,7 +20,6 @@ from __future__ import print_function import tensorflow as tf from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.ops import math_ops -from tensorflow.python import ops as _ops from privacy.bolton.losses import StrongConvexMixin _accepted_distributions = ['laplace'] # implemented distributions for noising diff --git a/privacy/bolton/optimizers_test.py b/privacy/bolton/optimizers_test.py index 1d0fbfb..6a499fc 100644 --- a/privacy/bolton/optimizers_test.py +++ b/privacy/bolton/optimizers_test.py @@ -25,7 +25,6 @@ from tensorflow.python.keras.regularizers import L1L2 from tensorflow.python.keras.initializers import constant from tensorflow.python.keras import losses from tensorflow.python.keras.models 
import Model -from tensorflow.python.framework import ops as _ops from tensorflow.python.framework import test_util from tensorflow.python import ops as _ops from absl.testing import parameterized @@ -33,7 +32,6 @@ from privacy.bolton.losses import StrongConvexMixin from privacy.bolton import optimizers as opt - class TestModel(Model): """ Bolton episilon-delta model @@ -69,18 +67,6 @@ class TestModel(Model): ) - # def call(self, inputs): - # """Forward pass of network - # - # Args: - # inputs: inputs to neural network - # - # Returns: - # - # """ - # return self.output_layer(inputs) - - class TestLoss(losses.Loss, StrongConvexMixin): """Test loss function for testing Bolton model""" def __init__(self, reg_lambda, C, radius_constant, name='test'): @@ -105,7 +91,7 @@ class TestLoss(losses.Loss, StrongConvexMixin): """ return _ops.convert_to_tensor_v2(1, dtype=tf.float32) - def beta(self, class_weight): + def beta(self, class_weight): # pylint: disable=unused-argument """Beta smoothess Args: @@ -116,7 +102,7 @@ class TestLoss(losses.Loss, StrongConvexMixin): """ return _ops.convert_to_tensor_v2(1, dtype=tf.float32) - def lipchitz_constant(self, class_weight): + def lipchitz_constant(self, class_weight): # pylint: disable=unused-argument """ L lipchitz continuous Args: @@ -217,7 +203,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase): model.layers[0].kernel = \ model.layers[0].kernel_initializer((model.layer_input_shape[0], model.n_outputs)) - bolton._is_init = True + bolton._is_init = True # pylint: disable=protected-access bolton.layers = model.layers bolton.epsilon = 2 bolton.noise_distribution = 'laplace' @@ -279,7 +265,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase): model.layers[0].kernel = \ model.layers[0].kernel_initializer((model.layer_input_shape[0], model.n_outputs)) - bolton._is_init = True + bolton._is_init = True # pylint: disable=protected-access bolton.layers = model.layers bolton.epsilon = 2 bolton.noise_distribution = 
'laplace' @@ -431,7 +417,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase): model.layers[0].kernel = \ model.layers[0].kernel_initializer((model.layer_input_shape[0], model.n_outputs)) - bolton._is_init = True + bolton._is_init = True # pylint: disable=protected-access bolton.layers = model.layers bolton.epsilon = 2 bolton.noise_distribution = 'laplace' @@ -467,7 +453,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase): model.layers[0].kernel = \ model.layers[0].kernel_initializer((model.layer_input_shape[0], model.n_outputs)) - bolton._is_init = True + bolton._is_init = True # pylint: disable=protected-access bolton.noise_distribution = 'laplace' bolton.epsilon = 1 bolton.layers = model.layers From 3080b654b570bbad6487653fa2b02b108a867d83 Mon Sep 17 00:00:00 2001 From: Christopher Choquette Choo Date: Wed, 19 Jun 2019 11:18:42 -0400 Subject: [PATCH 9/9] Minor changes to function arguments --- privacy/bolton/models.py | 2 +- tutorials/bolton_tutorial.ipynb | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/privacy/bolton/models.py b/privacy/bolton/models.py index 06d1c4b..7503157 100644 --- a/privacy/bolton/models.py +++ b/privacy/bolton/models.py @@ -60,7 +60,7 @@ class BoltonModel(Model): self._layers_instantiated = False self._dtype = dtype - def call(self, inputs, training=False): # pylint: disable=arguments-differ + def call(self, inputs): # pylint: disable=arguments-differ """Forward pass of network Args: diff --git a/tutorials/bolton_tutorial.ipynb b/tutorials/bolton_tutorial.ipynb index b60e612..f682592 100644 --- a/tutorials/bolton_tutorial.ipynb +++ b/tutorials/bolton_tutorial.ipynb @@ -321,7 +321,7 @@ " kernel_regularizer=reg_layer\n", " )\n", " \n", - " def call(self, inputs, training=False):\n", + " def call(self, inputs):\n", " return self.output_layer(inputs)\n", "\n", "optimizer = tf.optimizers.SGD()\n", @@ -420,13 +420,13 @@ "pycharm": { "stem_cell": { "cell_type": "raw", + "source": [], "metadata": { 
"collapsed": false - }, - "source": [] + } } } }, "nbformat": 4, "nbformat_minor": 1 -} +} \ No newline at end of file