forked from 626_privacy/tensorflow_privacy
Augmenting implementation to handle new tensorflow _validate_or_infer_batch_size implementation.
This commit is contained in:
parent
136200d0c2
commit
5ef3cec26e
2 changed files with 15 additions and 2 deletions
|
@ -157,6 +157,8 @@ class BoltOnModel(Model): # pylint: disable=abstract-method
|
||||||
batch_size_ = self._validate_or_infer_batch_size(batch_size,
|
batch_size_ = self._validate_or_infer_batch_size(batch_size,
|
||||||
steps_per_epoch,
|
steps_per_epoch,
|
||||||
x)
|
x)
|
||||||
|
if batch_size_ is None:
|
||||||
|
batch_size_ = 32
|
||||||
# inferring batch_size to be passed to optimizer. batch_size must remain its
|
# inferring batch_size to be passed to optimizer. batch_size must remain its
|
||||||
# initial value when passed to super().fit()
|
# initial value when passed to super().fit()
|
||||||
if batch_size_ is None:
|
if batch_size_ is None:
|
||||||
|
@ -219,6 +221,8 @@ class BoltOnModel(Model): # pylint: disable=abstract-method
|
||||||
batch_size = self._validate_or_infer_batch_size(None,
|
batch_size = self._validate_or_infer_batch_size(None,
|
||||||
steps_per_epoch,
|
steps_per_epoch,
|
||||||
generator)
|
generator)
|
||||||
|
if batch_size is None:
|
||||||
|
batch_size = 32
|
||||||
with self.optimizer(noise_distribution,
|
with self.optimizer(noise_distribution,
|
||||||
epsilon,
|
epsilon,
|
||||||
self.layers,
|
self.layers,
|
||||||
|
|
|
@ -215,7 +215,7 @@ class InitTests(keras_parameterized.TestCase):
|
||||||
clf.compile(optimizer, loss)
|
clf.compile(optimizer, loss)
|
||||||
|
|
||||||
|
|
||||||
def _cat_dataset(n_samples, input_dim, n_classes, generator=False):
|
def _cat_dataset(n_samples, input_dim, n_classes, batch_size, generator=False):
|
||||||
"""Creates a categorically encoded dataset.
|
"""Creates a categorically encoded dataset.
|
||||||
|
|
||||||
Creates a categorically encoded dataset (y is categorical).
|
Creates a categorically encoded dataset (y is categorical).
|
||||||
|
@ -228,6 +228,7 @@ def _cat_dataset(n_samples, input_dim, n_classes, generator=False):
|
||||||
input_dim: input dimensionality
|
input_dim: input dimensionality
|
||||||
n_classes: output dimensionality
|
n_classes: output dimensionality
|
||||||
generator: False for array, True for generator
|
generator: False for array, True for generator
|
||||||
|
batch_size: The desired batch_size.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
X as (n_samples, input_dim), Y as (n_samples, n_outputs)
|
X as (n_samples, input_dim), Y as (n_samples, n_outputs)
|
||||||
|
@ -246,6 +247,7 @@ def _cat_dataset(n_samples, input_dim, n_classes, generator=False):
|
||||||
dataset = tf.data.Dataset.from_tensor_slices(
|
dataset = tf.data.Dataset.from_tensor_slices(
|
||||||
(x_set, y_set)
|
(x_set, y_set)
|
||||||
)
|
)
|
||||||
|
dataset = dataset.batch(batch_size=batch_size)
|
||||||
return dataset
|
return dataset
|
||||||
return x_set, y_set
|
return x_set, y_set
|
||||||
|
|
||||||
|
@ -285,6 +287,7 @@ def _do_fit(n_samples,
|
||||||
n_samples,
|
n_samples,
|
||||||
input_dim,
|
input_dim,
|
||||||
n_outputs,
|
n_outputs,
|
||||||
|
batch_size,
|
||||||
generator=generator
|
generator=generator
|
||||||
)
|
)
|
||||||
y = None
|
y = None
|
||||||
|
@ -292,7 +295,12 @@ def _do_fit(n_samples,
|
||||||
x = x.shuffle(n_samples//2)
|
x = x.shuffle(n_samples//2)
|
||||||
batch_size = None
|
batch_size = None
|
||||||
else:
|
else:
|
||||||
x, y = _cat_dataset(n_samples, input_dim, n_outputs, generator=generator)
|
x, y = _cat_dataset(
|
||||||
|
n_samples,
|
||||||
|
input_dim,
|
||||||
|
n_outputs,
|
||||||
|
batch_size,
|
||||||
|
generator=generator)
|
||||||
if reset_n_samples:
|
if reset_n_samples:
|
||||||
n_samples = None
|
n_samples = None
|
||||||
|
|
||||||
|
@ -377,6 +385,7 @@ class FitTests(keras_parameterized.TestCase):
|
||||||
n_samples,
|
n_samples,
|
||||||
input_dim,
|
input_dim,
|
||||||
n_classes,
|
n_classes,
|
||||||
|
batch_size,
|
||||||
generator=generator
|
generator=generator
|
||||||
)
|
)
|
||||||
x = x.batch(batch_size)
|
x = x.batch(batch_size)
|
||||||
|
|
Loading…
Reference in a new issue