Merge pull request #70 from georgianpartners:master

PiperOrigin-RevId: 265056745
A. Unique TensorFlower 2019-08-23 08:06:11 -07:00
commit 0e84af1e69
6 changed files with 55 additions and 41 deletions


@@ -42,6 +42,16 @@ delta-epsilon privacy in machine learning, some of which can be explored here:
 https://medium.com/apache-mxnet/epsilon-differential-privacy-for-machine-learning-using-mxnet-a4270fe3865e
 https://arxiv.org/pdf/1811.04911.pdf
 
+## Stability
+
+As this package is pegged to tensorflow 2.0, it may encounter stability
+issues as development of tensorflow 2.0 continues.
+
+This sub-package is currently stable for 2.0.0a0, 2.0.0b0, and 2.0.0b1. If you
+would like to use this sub-package, please use one of these versions, as we
+cannot guarantee it will work with later releases. If you do find issues,
+feel free to raise an issue with the contributors listed below.
+
 ## Contacts
 
 In addition to the maintainers of tensorflow/privacy listed in the root
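Since the supported releases are enumerated explicitly, a version guard makes the constraint visible at import time. A minimal sketch, assuming the version strings reported by the TF 2.0 pre-releases ('2.0.0-alpha0', '2.0.0-beta0', '2.0.0-beta1'); these strings are an assumption, so adjust them to the builds you actually install:

import tensorflow as tf

# Hypothetical guard for the constraint described in the Stability section:
# warn early when running on a TF build this sub-package is not verified on.
_VERIFIED_TF = ('2.0.0-alpha0', '2.0.0-beta0', '2.0.0-beta1')
if tf.__version__ not in _VERIFIED_TF:
  print('Warning: bolt_on is only verified on {}; found {}'.format(
      ', '.join(_VERIFIED_TF), tf.__version__))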


@@ -217,9 +217,11 @@ class BoltOnModel(Model):  # pylint: disable=abstract-method
     elif hasattr(generator, '__len__'):
       data_size = len(generator)
     else:
-      data_size = None
-    batch_size = self._validate_or_infer_batch_size(None,
-                                                    steps_per_epoch,
+      raise ValueError('The number of samples could not be determined. '
+                       'If you are using a generator, please call this '
+                       'method directly with the n_samples kwarg '
+                       'passed.')
+    batch_size = self._validate_or_infer_batch_size(None, steps_per_epoch,
                                                     generator)
     if batch_size is None:
       batch_size = 32
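With this change, a generator that exposes neither shape nor __len__ (a plain tf.data.Dataset, for example) must come with an explicit sample count instead of silently falling through with data_size = None. A hedged sketch of the call the error message asks for, assuming a compiled BoltOnModel named bolt, a dataset ds of unknown length, and 'laplace' as an accepted noise distribution:

# ds exposes neither `shape` nor `__len__`, so the check above would raise;
# passing n_samples explicitly sidesteps the size inference entirely.
bolt.fit_generator(ds,
                   n_samples=20,                  # explicit dataset size
                   noise_distribution='laplace',
                   epsilon=2)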


@@ -294,6 +294,12 @@ def _do_fit(n_samples,
     # x = x.batch(batch_size)
     x = x.shuffle(n_samples//2)
     batch_size = None
+    if reset_n_samples:
+      n_samples = None
+    clf.fit_generator(x,
+                      n_samples=n_samples,
+                      noise_distribution=distribution,
+                      epsilon=epsilon)
   else:
     x, y = _cat_dataset(
         n_samples,
@@ -301,15 +307,14 @@ def _do_fit(n_samples,
         n_outputs,
         batch_size,
         generator=generator)
-
-  if reset_n_samples:
-    n_samples = None
-  clf.fit(x,
-          y,
-          batch_size=batch_size,
-          n_samples=n_samples,
-          noise_distribution=distribution,
-          epsilon=epsilon)
+    if reset_n_samples:
+      n_samples = None
+    clf.fit(x,
+            y,
+            batch_size=batch_size,
+            n_samples=n_samples,
+            noise_distribution=distribution,
+            epsilon=epsilon)
   return clf
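After this restructuring, each branch fits on its own: the generator branch uses fit_generator, which takes no batch_size because batching comes from the dataset pipeline, while the in-memory branch keeps the plain fit call with an explicit batch_size. Both still forward n_samples, noise_distribution, and epsilon, since the BoltOn method calibrates its post-fit noise from the sample count and the privacy parameter. A condensed view of the resulting control flow (names as in the tutorial):

if generator:    # tf.data path: batches come from the pipeline
  clf.fit_generator(x, n_samples=n_samples,
                    noise_distribution=distribution, epsilon=epsilon)
else:            # in-memory path: explicit batch_size
  clf.fit(x, y, batch_size=batch_size, n_samples=n_samples,
          noise_distribution=distribution, epsilon=epsilon)
return clf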


@@ -129,18 +129,19 @@ class BoltOn(optimizer_v2.OptimizerV2):
     if not isinstance(loss, StrongConvexMixin):
       raise ValueError('loss function must be strongly convex and therefore '
                        'extend the StrongConvexMixin.')
-    self._private_attributes = ['_internal_optimizer',
-                                'dtype',
-                                'noise_distribution',
-                                'epsilon',
-                                'loss',
-                                'class_weights',
-                                'input_dim',
-                                'n_samples',
-                                'layers',
-                                'batch_size',
-                                '_is_init'
-                               ]
+    self._private_attributes = [
+        '_internal_optimizer',
+        'dtype',
+        'noise_distribution',
+        'epsilon',
+        'loss',
+        'class_weights',
+        'input_dim',
+        'n_samples',
+        'layers',
+        'batch_size',
+        '_is_init',
+    ]
     self._internal_optimizer = optimizer
     self.learning_rate = GammaBetaDecreasingStep()  # use the BoltOn learning
     # rate scheduler, as required for privacy guarantees. This will still need
@@ -250,8 +251,7 @@ class BoltOn(optimizer_v2.OptimizerV2):
         "Neither '{0}' nor '{1}' object has attribute '{2}'"
         "".format(self.__class__.__name__,
                   self._internal_optimizer.__class__.__name__,
-                  name
-                  )
+                  name)
     )

   def __setattr__(self, key, value):
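This two-level AttributeError is the visible edge of BoltOn's delegation pattern: __getattr__ falls through to the wrapped internal optimizer, and __setattr__ (above) consults _private_attributes to decide whether a write belongs to the wrapper or should be forwarded. A stripped-down sketch of the pattern, independent of the BoltOn classes and not their exact implementation:

class DelegatingOptimizer(object):
  """Toy wrapper mirroring BoltOn's attribute forwarding."""

  def __init__(self, inner):
    self._private_attributes = ['_private_attributes', '_internal_optimizer']
    self.__dict__['_internal_optimizer'] = inner  # bypass __setattr__ once

  def __getattr__(self, name):
    # Reached only when normal lookup fails on the wrapper itself.
    try:
      return getattr(self._internal_optimizer, name)
    except AttributeError:
      raise AttributeError(
          "Neither '{0}' nor '{1}' object has attribute '{2}'".format(
              self.__class__.__name__,
              self._internal_optimizer.__class__.__name__, name))

  def __setattr__(self, key, value):
    if key == '_private_attributes' or key in self._private_attributes:
      object.__setattr__(self, key, value)           # wrapper-owned state
    else:
      setattr(self._internal_optimizer, key, value)  # forward to inner

inner = type('Inner', (), {'lr': 0.1})()
opt = DelegatingOptimizer(inner)
assert opt.lr == 0.1  # read falls through to the wrapped object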
@@ -319,8 +319,7 @@ class BoltOn(optimizer_v2.OptimizerV2):
                layers,
                class_weights,
                n_samples,
-               batch_size
-              ):
+               batch_size):
     """Accepts required values for bolton method from context entry point.

     Stores them on the optimizer for use throughout fitting.
@@ -347,8 +346,7 @@ class BoltOn(optimizer_v2.OptimizerV2):
                                _accepted_distributions))
     self.noise_distribution = noise_distribution
     self.learning_rate.initialize(self.loss.beta(class_weights),
-                                  self.loss.gamma()
-                                  )
+                                  self.loss.gamma())
     self.epsilon = tf.constant(epsilon, dtype=self.dtype)
     self.class_weights = tf.constant(class_weights, dtype=self.dtype)
     self.n_samples = tf.constant(n_samples, dtype=self.dtype)
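For orientation, __call__ here only stores the fitting state; the optimizer is then entered as a context manager around the actual fit, and (in the BoltOn implementation) the calibrated noise is applied when the context exits. A toy reduction of that protocol, not the BoltOn API itself:

class NoisyFitContext(object):
  """Toy entry point: __call__ stores state, the with-block scopes it."""

  def __call__(self, noise_distribution, epsilon, n_samples, batch_size):
    self.noise_distribution = noise_distribution
    self.epsilon = epsilon      # cf. the tf.constant assignments above
    self.n_samples = n_samples
    self.batch_size = batch_size
    return self

  def __enter__(self):
    return self

  def __exit__(self, *exc_info):
    # BoltOn would project weights and add calibrated noise here.
    return False

ctx = NoisyFitContext()
with ctx('laplace', 2, 20, 4) as c:
  assert c.epsilon == 2         # state is available throughout fitting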


@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import unittest
+from absl.testing import parameterized
 import tensorflow as tf
 from tensorflow.python import ops as _ops
@@ -270,7 +271,6 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
       result: the expected output after projection.
     """
     tf.random.set_seed(1)
-    @tf.function
     def project_fn(r):
       loss = TestLoss(1, 1, r)
       bolton = opt.BoltOn(TestOptimizer(), loss)
@@ -358,7 +358,8 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
       {'testcase_name': 'fn: get_noise',
        'fn': 'get_noise',
        'args': [1, 1],
-       'err_msg': 'ust be called from within the optimizer\'s context'},
+       'err_msg': 'This method must be called from within the '
+                  'optimizer\'s context'},
   ])
   def test_not_in_context(self, fn, args, err_msg):
     """Tests that the expected functions raise errors when not in context.
@@ -368,7 +369,6 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
       args: the arguments for said function
       err_msg: expected error message
     """
-    @tf.function
     def test_run(fn, args):
       loss = TestLoss(1, 1, 1)
       bolton = opt.BoltOn(TestOptimizer(), loss)
@@ -462,7 +462,6 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
       fn: fn to test
       args: arguments to that fn
     """
-    @tf.function
     def test_run(fn, args):
       loss = TestLoss(1, 1, 1)
       bolton = opt.BoltOn(TestOptimizer(), loss)
@@ -577,3 +576,4 @@ class SchedulerTest(keras_parameterized.TestCase):

 if __name__ == '__main__':
   test.main()
+  unittest.main()
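The new absl.testing.parameterized import is what backs the named_parameters tables used in the tests above. For readers unfamiliar with the idiom, a self-contained sketch (names and values are illustrative only):

import unittest
from absl.testing import parameterized

class ExampleTest(parameterized.TestCase):

  @parameterized.named_parameters([
      {'testcase_name': 'one', 'value': 1},
      {'testcase_name': 'two', 'value': 2},
  ])
  def test_positive(self, value):
    # Expands to test_positive_one and test_positive_two.
    self.assertGreater(value, 0)

if __name__ == '__main__':
  unittest.main()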


@@ -124,13 +124,12 @@ except ValueError as e:
 # And now, re-running with the parameter set.
 # -------
 n_samples = 20
-bolt.fit(generator,
-         epsilon=epsilon,
-         class_weight=class_weight,
-         batch_size=batch_size,
-         n_samples=n_samples,
-         noise_distribution=noise_distribution,
-         verbose=0)
+bolt.fit_generator(generator,
+                   epsilon=epsilon,
+                   class_weight=class_weight,
+                   n_samples=n_samples,
+                   noise_distribution=noise_distribution,
+                   verbose=0)
 # -------
 # You don't have to use the BoltOn model to use the BoltOn method.
 # There are only a few requirements:
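A closing note on the bolt.fit_generator call above: Keras generator fitting draws batches from the generator itself, which is why the old call's batch_size argument is gone; the remaining arguments, including n_samples, carry over unchanged.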