forked from 626_privacy/tensorflow_privacy
Augmenting implementation to handle new tensorflow _validate_or_infer_batch_size implementation.
This commit is contained in:
parent
136200d0c2
commit
5ef3cec26e
2 changed files with 15 additions and 2 deletions
|
@ -157,6 +157,8 @@ class BoltOnModel(Model): # pylint: disable=abstract-method
|
|||
batch_size_ = self._validate_or_infer_batch_size(batch_size,
|
||||
steps_per_epoch,
|
||||
x)
|
||||
if batch_size_ is None:
|
||||
batch_size_ = 32
|
||||
# inferring batch_size to be passed to optimizer. batch_size must remain its
|
||||
# initial value when passed to super().fit()
|
||||
if batch_size_ is None:
|
||||
|
@ -219,6 +221,8 @@ class BoltOnModel(Model): # pylint: disable=abstract-method
|
|||
batch_size = self._validate_or_infer_batch_size(None,
|
||||
steps_per_epoch,
|
||||
generator)
|
||||
if batch_size is None:
|
||||
batch_size = 32
|
||||
with self.optimizer(noise_distribution,
|
||||
epsilon,
|
||||
self.layers,
|
||||
|
|
|
@ -215,7 +215,7 @@ class InitTests(keras_parameterized.TestCase):
|
|||
clf.compile(optimizer, loss)
|
||||
|
||||
|
||||
def _cat_dataset(n_samples, input_dim, n_classes, generator=False):
|
||||
def _cat_dataset(n_samples, input_dim, n_classes, batch_size, generator=False):
|
||||
"""Creates a categorically encoded dataset.
|
||||
|
||||
Creates a categorically encoded dataset (y is categorical).
|
||||
|
@ -228,6 +228,7 @@ def _cat_dataset(n_samples, input_dim, n_classes, generator=False):
|
|||
input_dim: input dimensionality
|
||||
n_classes: output dimensionality
|
||||
generator: False for array, True for generator
|
||||
batch_size: The desired batch_size.
|
||||
|
||||
Returns:
|
||||
X as (n_samples, input_dim), Y as (n_samples, n_outputs)
|
||||
|
@ -246,6 +247,7 @@ def _cat_dataset(n_samples, input_dim, n_classes, generator=False):
|
|||
dataset = tf.data.Dataset.from_tensor_slices(
|
||||
(x_set, y_set)
|
||||
)
|
||||
dataset = dataset.batch(batch_size=batch_size)
|
||||
return dataset
|
||||
return x_set, y_set
|
||||
|
||||
|
@ -285,6 +287,7 @@ def _do_fit(n_samples,
|
|||
n_samples,
|
||||
input_dim,
|
||||
n_outputs,
|
||||
batch_size,
|
||||
generator=generator
|
||||
)
|
||||
y = None
|
||||
|
@ -292,7 +295,12 @@ def _do_fit(n_samples,
|
|||
x = x.shuffle(n_samples//2)
|
||||
batch_size = None
|
||||
else:
|
||||
x, y = _cat_dataset(n_samples, input_dim, n_outputs, generator=generator)
|
||||
x, y = _cat_dataset(
|
||||
n_samples,
|
||||
input_dim,
|
||||
n_outputs,
|
||||
batch_size,
|
||||
generator=generator)
|
||||
if reset_n_samples:
|
||||
n_samples = None
|
||||
|
||||
|
@ -377,6 +385,7 @@ class FitTests(keras_parameterized.TestCase):
|
|||
n_samples,
|
||||
input_dim,
|
||||
n_classes,
|
||||
batch_size,
|
||||
generator=generator
|
||||
)
|
||||
x = x.batch(batch_size)
|
||||
|
|
Loading…
Reference in a new issue