Augmenting implementation to handle new tensorflow _validate_or_infer_batch_size implementation.

This commit is contained in:
Christopher Choquette Choo 2019-08-06 11:00:22 -04:00
parent 136200d0c2
commit 5ef3cec26e
2 changed files with 15 additions and 2 deletions

View file

@ -157,6 +157,8 @@ class BoltOnModel(Model): # pylint: disable=abstract-method
batch_size_ = self._validate_or_infer_batch_size(batch_size,
steps_per_epoch,
x)
if batch_size_ is None:
batch_size_ = 32
# inferring batch_size to be passed to optimizer. batch_size must remain its
# initial value when passed to super().fit()
if batch_size_ is None:
@ -219,6 +221,8 @@ class BoltOnModel(Model): # pylint: disable=abstract-method
batch_size = self._validate_or_infer_batch_size(None,
steps_per_epoch,
generator)
if batch_size is None:
batch_size = 32
with self.optimizer(noise_distribution,
epsilon,
self.layers,

View file

@ -215,7 +215,7 @@ class InitTests(keras_parameterized.TestCase):
clf.compile(optimizer, loss)
def _cat_dataset(n_samples, input_dim, n_classes, generator=False):
def _cat_dataset(n_samples, input_dim, n_classes, batch_size, generator=False):
"""Creates a categorically encoded dataset.
Creates a categorically encoded dataset (y is categorical).
@ -228,6 +228,7 @@ def _cat_dataset(n_samples, input_dim, n_classes, generator=False):
input_dim: input dimensionality
n_classes: output dimensionality
generator: False for array, True for generator
batch_size: The desired batch_size.
Returns:
X as (n_samples, input_dim), Y as (n_samples, n_outputs)
@ -246,6 +247,7 @@ def _cat_dataset(n_samples, input_dim, n_classes, generator=False):
dataset = tf.data.Dataset.from_tensor_slices(
(x_set, y_set)
)
dataset = dataset.batch(batch_size=batch_size)
return dataset
return x_set, y_set
@ -285,6 +287,7 @@ def _do_fit(n_samples,
n_samples,
input_dim,
n_outputs,
batch_size,
generator=generator
)
y = None
@ -292,7 +295,12 @@ def _do_fit(n_samples,
x = x.shuffle(n_samples//2)
batch_size = None
else:
x, y = _cat_dataset(n_samples, input_dim, n_outputs, generator=generator)
x, y = _cat_dataset(
n_samples,
input_dim,
n_outputs,
batch_size,
generator=generator)
if reset_n_samples:
n_samples = None
@ -377,6 +385,7 @@ class FitTests(keras_parameterized.TestCase):
n_samples,
input_dim,
n_classes,
batch_size,
generator=generator
)
x = x.batch(batch_size)