diff --git a/privacy/bolt_on/models.py b/privacy/bolt_on/models.py index 7cdcccd..c504103 100644 --- a/privacy/bolt_on/models.py +++ b/privacy/bolt_on/models.py @@ -157,6 +157,8 @@ class BoltOnModel(Model): # pylint: disable=abstract-method batch_size_ = self._validate_or_infer_batch_size(batch_size, steps_per_epoch, x) + if batch_size_ is None: + batch_size_ = 32 # inferring batch_size to be passed to optimizer. batch_size must remain its # initial value when passed to super().fit() if batch_size_ is None: @@ -219,6 +221,8 @@ class BoltOnModel(Model): # pylint: disable=abstract-method batch_size = self._validate_or_infer_batch_size(None, steps_per_epoch, generator) + if batch_size is None: + batch_size = 32 with self.optimizer(noise_distribution, epsilon, self.layers, diff --git a/privacy/bolt_on/models_test.py b/privacy/bolt_on/models_test.py index 522f686..772f792 100644 --- a/privacy/bolt_on/models_test.py +++ b/privacy/bolt_on/models_test.py @@ -215,7 +215,7 @@ class InitTests(keras_parameterized.TestCase): clf.compile(optimizer, loss) -def _cat_dataset(n_samples, input_dim, n_classes, generator=False): +def _cat_dataset(n_samples, input_dim, n_classes, batch_size, generator=False): """Creates a categorically encoded dataset. Creates a categorically encoded dataset (y is categorical). @@ -227,6 +227,7 @@ def _cat_dataset(n_samples, input_dim, n_classes, generator=False): n_samples: number of rows input_dim: input dimensionality n_classes: output dimensionality + batch_size: The desired batch_size generator: False for array, True for generator Returns: @@ -246,6 +247,7 @@ def _cat_dataset(n_samples, input_dim, n_classes, generator=False): dataset = tf.data.Dataset.from_tensor_slices( (x_set, y_set) ) + dataset = dataset.batch(batch_size=batch_size) return dataset return x_set, y_set @@ -285,6 +287,7 @@ def _do_fit(n_samples, n_samples, input_dim, n_outputs, + batch_size, generator=generator ) y = None @@ -292,7 +295,12 @@ def _do_fit(n_samples, x = x.shuffle(n_samples//2) batch_size = None else: - x, y = _cat_dataset(n_samples, input_dim, n_outputs, generator=generator) + x, y = _cat_dataset( + n_samples, + input_dim, + n_outputs, + batch_size, + generator=generator) if reset_n_samples: n_samples = None @@ -377,6 +385,7 @@ class FitTests(keras_parameterized.TestCase): n_samples, input_dim, n_classes, + batch_size, generator=generator ) x = x.batch(batch_size)