forked from 626_privacy/tensorflow_privacy
Stable version for tf2.0a0, b0.
This commit is contained in:
parent
7d885640ec
commit
18ce9c2335
6 changed files with 39 additions and 35 deletions
|
@ -44,11 +44,12 @@ https://arxiv.org/pdf/1811.04911.pdf
|
|||
|
||||
## Stability
|
||||
|
||||
As we are pegged on tensorflow2.0.0, this package may encounter stability
|
||||
issues in the ongoing development of this package.
|
||||
As we are pegged on tensorflow2.0, this package may encounter stability
|
||||
issues in the ongoing development of tensorflow2.0.
|
||||
|
||||
We are aware of issues in model fitting using the BoltOnModel and are actively
|
||||
working towards solving these issues.
|
||||
This sub-package is currently stable for 2.0.0a0 and 2.0.0b0. We are aware of
|
||||
issues in model fitting using the BoltOnModel in beta1, the latest release,
|
||||
and are actively working towards solving these issues.
|
||||
|
||||
## Contacts
|
||||
|
||||
|
|
|
@ -217,7 +217,10 @@ class BoltOnModel(Model): # pylint: disable=abstract-method
|
|||
elif hasattr(generator, '__len__'):
|
||||
data_size = len(generator)
|
||||
else:
|
||||
data_size = None
|
||||
raise ValueError("The number of samples could not be determined. "
|
||||
"Please make sure that if you are using a generator"
|
||||
"to call this method directly with n_samples kwarg "
|
||||
"passed.")
|
||||
batch_size = self._validate_or_infer_batch_size(None,
|
||||
steps_per_epoch,
|
||||
generator)
|
||||
|
|
|
@ -227,8 +227,8 @@ def _cat_dataset(n_samples, input_dim, n_classes, batch_size, generator=False):
|
|||
n_samples: number of rows
|
||||
input_dim: input dimensionality
|
||||
n_classes: output dimensionality
|
||||
batch_size: The desired batch_size
|
||||
generator: False for array, True for generator
|
||||
batch_size: The desired batch_size.
|
||||
|
||||
Returns:
|
||||
X as (n_samples, input_dim), Y as (n_samples, n_outputs)
|
||||
|
@ -294,6 +294,12 @@ def _do_fit(n_samples,
|
|||
# x = x.batch(batch_size)
|
||||
x = x.shuffle(n_samples//2)
|
||||
batch_size = None
|
||||
if reset_n_samples:
|
||||
n_samples = None
|
||||
clf.fit_generator(x,
|
||||
n_samples=n_samples,
|
||||
noise_distribution=distribution,
|
||||
epsilon=epsilon)
|
||||
else:
|
||||
x, y = _cat_dataset(
|
||||
n_samples,
|
||||
|
@ -303,7 +309,6 @@ def _do_fit(n_samples,
|
|||
generator=generator)
|
||||
if reset_n_samples:
|
||||
n_samples = None
|
||||
|
||||
clf.fit(x,
|
||||
y,
|
||||
batch_size=batch_size,
|
||||
|
|
|
@ -139,7 +139,7 @@ class BoltOn(optimizer_v2.OptimizerV2):
|
|||
'n_samples',
|
||||
'layers',
|
||||
'batch_size',
|
||||
'_is_init'
|
||||
'_is_init',
|
||||
]
|
||||
self._internal_optimizer = optimizer
|
||||
self.learning_rate = GammaBetaDecreasingStep() # use the BoltOn Learning
|
||||
|
@ -250,8 +250,7 @@ class BoltOn(optimizer_v2.OptimizerV2):
|
|||
"Neither '{0}' nor '{1}' object has attribute '{2}'"
|
||||
"".format(self.__class__.__name__,
|
||||
self._internal_optimizer.__class__.__name__,
|
||||
name
|
||||
)
|
||||
name)
|
||||
)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
|
@ -319,8 +318,7 @@ class BoltOn(optimizer_v2.OptimizerV2):
|
|||
layers,
|
||||
class_weights,
|
||||
n_samples,
|
||||
batch_size
|
||||
):
|
||||
batch_size):
|
||||
"""Accepts required values for bolton method from context entry point.
|
||||
|
||||
Stores them on the optimizer for use throughout fitting.
|
||||
|
@ -347,8 +345,7 @@ class BoltOn(optimizer_v2.OptimizerV2):
|
|||
_accepted_distributions))
|
||||
self.noise_distribution = noise_distribution
|
||||
self.learning_rate.initialize(self.loss.beta(class_weights),
|
||||
self.loss.gamma()
|
||||
)
|
||||
self.loss.gamma())
|
||||
self.epsilon = tf.constant(epsilon, dtype=self.dtype)
|
||||
self.class_weights = tf.constant(class_weights, dtype=self.dtype)
|
||||
self.n_samples = tf.constant(n_samples, dtype=self.dtype)
|
||||
|
|
|
@ -199,7 +199,6 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
|
|||
'result': None,
|
||||
'test_attr': ''},
|
||||
])
|
||||
|
||||
def test_fn(self, fn, args, result, test_attr):
|
||||
"""test that a fn of BoltOn optimizer is working as expected.
|
||||
|
||||
|
@ -270,7 +269,6 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
|
|||
result: the expected output after projection.
|
||||
"""
|
||||
tf.random.set_seed(1)
|
||||
@tf.function
|
||||
def project_fn(r):
|
||||
loss = TestLoss(1, 1, r)
|
||||
bolton = opt.BoltOn(TestOptimizer(), loss)
|
||||
|
@ -358,7 +356,8 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
|
|||
{'testcase_name': 'fn: get_noise',
|
||||
'fn': 'get_noise',
|
||||
'args': [1, 1],
|
||||
'err_msg': 'ust be called from within the optimizer\'s context'},
|
||||
'err_msg': 'This method must be called from within the '
|
||||
'optimizer\'s context'},
|
||||
])
|
||||
def test_not_in_context(self, fn, args, err_msg):
|
||||
"""Tests that the expected functions raise errors when not in context.
|
||||
|
@ -368,7 +367,6 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
|
|||
args: the arguments for said function
|
||||
err_msg: expected error message
|
||||
"""
|
||||
@tf.function
|
||||
def test_run(fn, args):
|
||||
loss = TestLoss(1, 1, 1)
|
||||
bolton = opt.BoltOn(TestOptimizer(), loss)
|
||||
|
@ -462,7 +460,6 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
|
|||
fn: fn to test
|
||||
args: arguments to that fn
|
||||
"""
|
||||
@tf.function
|
||||
def test_run(fn, args):
|
||||
loss = TestLoss(1, 1, 1)
|
||||
bolton = opt.BoltOn(TestOptimizer(), loss)
|
||||
|
@ -577,3 +574,5 @@ class SchedulerTest(keras_parameterized.TestCase):
|
|||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
import unittest
|
||||
unittest.main()
|
||||
|
|
|
@ -124,10 +124,9 @@ except ValueError as e:
|
|||
# And now, re running with the parameter set.
|
||||
# -------
|
||||
n_samples = 20
|
||||
bolt.fit(generator,
|
||||
bolt.fit_generator(generator,
|
||||
epsilon=epsilon,
|
||||
class_weight=class_weight,
|
||||
batch_size=batch_size,
|
||||
n_samples=n_samples,
|
||||
noise_distribution=noise_distribution,
|
||||
verbose=0)
|
||||
|
|
Loading…
Reference in a new issue