diff --git a/tensorflow_privacy/privacy/bolt_on/models.py b/tensorflow_privacy/privacy/bolt_on/models.py index 49bf466..4e325e3 100644 --- a/tensorflow_privacy/privacy/bolt_on/models.py +++ b/tensorflow_privacy/privacy/bolt_on/models.py @@ -272,8 +272,9 @@ class BoltOnModel(Model): # pylint: disable=abstract-method num_samples = sum(class_counts) weighted_counts = tf.dtypes.cast( tf.math.multiply(num_classes, class_counts), self._dtype) - class_weights = tf.Variable(num_samples, dtype=self._dtype) / \ - tf.Variable(weighted_counts, dtype=self._dtype) + class_weights = ( + tf.Variable(num_samples, dtype=self._dtype) / + tf.Variable(weighted_counts, dtype=self._dtype)) else: class_weights = _ops.convert_to_tensor_v2(class_weights) if len(class_weights.shape) != 1: