Merge pull request #59 from georgianpartners:master

PiperOrigin-RevId: 262652439
This commit is contained in:
A. Unique TensorFlower 2019-08-09 16:04:33 -07:00
commit 4164243a99
2 changed files with 15 additions and 2 deletions

View file

@ -157,6 +157,8 @@ class BoltOnModel(Model): # pylint: disable=abstract-method
batch_size_ = self._validate_or_infer_batch_size(batch_size, batch_size_ = self._validate_or_infer_batch_size(batch_size,
steps_per_epoch, steps_per_epoch,
x) x)
if batch_size_ is None:
batch_size_ = 32
# inferring batch_size to be passed to optimizer. batch_size must remain its # inferring batch_size to be passed to optimizer. batch_size must remain its
# initial value when passed to super().fit() # initial value when passed to super().fit()
if batch_size_ is None: if batch_size_ is None:
@ -219,6 +221,8 @@ class BoltOnModel(Model): # pylint: disable=abstract-method
batch_size = self._validate_or_infer_batch_size(None, batch_size = self._validate_or_infer_batch_size(None,
steps_per_epoch, steps_per_epoch,
generator) generator)
if batch_size is None:
batch_size = 32
with self.optimizer(noise_distribution, with self.optimizer(noise_distribution,
epsilon, epsilon,
self.layers, self.layers,

View file

@ -215,7 +215,7 @@ class InitTests(keras_parameterized.TestCase):
clf.compile(optimizer, loss) clf.compile(optimizer, loss)
def _cat_dataset(n_samples, input_dim, n_classes, generator=False): def _cat_dataset(n_samples, input_dim, n_classes, batch_size, generator=False):
"""Creates a categorically encoded dataset. """Creates a categorically encoded dataset.
Creates a categorically encoded dataset (y is categorical). Creates a categorically encoded dataset (y is categorical).
@ -227,6 +227,7 @@ def _cat_dataset(n_samples, input_dim, n_classes, generator=False):
n_samples: number of rows n_samples: number of rows
input_dim: input dimensionality input_dim: input dimensionality
n_classes: output dimensionality n_classes: output dimensionality
batch_size: The desired batch_size
generator: False for array, True for generator generator: False for array, True for generator
Returns: Returns:
@ -246,6 +247,7 @@ def _cat_dataset(n_samples, input_dim, n_classes, generator=False):
dataset = tf.data.Dataset.from_tensor_slices( dataset = tf.data.Dataset.from_tensor_slices(
(x_set, y_set) (x_set, y_set)
) )
dataset = dataset.batch(batch_size=batch_size)
return dataset return dataset
return x_set, y_set return x_set, y_set
@ -285,6 +287,7 @@ def _do_fit(n_samples,
n_samples, n_samples,
input_dim, input_dim,
n_outputs, n_outputs,
batch_size,
generator=generator generator=generator
) )
y = None y = None
@ -292,7 +295,12 @@ def _do_fit(n_samples,
x = x.shuffle(n_samples//2) x = x.shuffle(n_samples//2)
batch_size = None batch_size = None
else: else:
x, y = _cat_dataset(n_samples, input_dim, n_outputs, generator=generator) x, y = _cat_dataset(
n_samples,
input_dim,
n_outputs,
batch_size,
generator=generator)
if reset_n_samples: if reset_n_samples:
n_samples = None n_samples = None
@ -377,6 +385,7 @@ class FitTests(keras_parameterized.TestCase):
n_samples, n_samples,
input_dim, input_dim,
n_classes, n_classes,
batch_size,
generator=generator generator=generator
) )
x = x.batch(batch_size) x = x.batch(batch_size)