Fixing missing args.

This commit is contained in:
Christopher Choquette Choo 2019-07-27 13:54:19 -04:00
parent 0317ce8077
commit 92f97ae32c
3 changed files with 28 additions and 27 deletions

View file

@ -76,18 +76,17 @@ class BoltonModel(Model): # pylint: disable=abstract-method
def compile(self, def compile(self,
optimizer, optimizer,
loss, loss,
metrics=None,
loss_weights=None,
sample_weight_mode=None,
weighted_metrics=None,
target_tensors=None,
distribute=None,
kernel_initializer=tf.initializers.GlorotUniform, kernel_initializer=tf.initializers.GlorotUniform,
**kwargs): # pylint: disable=arguments-differ **kwargs): # pylint: disable=arguments-differ
"""See super class. Default optimizer used in Bolton method is SGD. """See super class. Default optimizer used in Bolton method is SGD.
Missing args. Args:
optimizer: The optimizer to use. This will be automatically wrapped
with the Bolton Optimizer.
loss: The loss function to use. Must be a StrongConvex loss (extend the
StrongConvexMixin).
kernel_initializer: The kernel initializer to use for the single layer.
kwargs: kwargs to keras Model.compile. See super.
""" """
if not isinstance(loss, StrongConvexMixin): if not isinstance(loss, StrongConvexMixin):
raise ValueError('loss function must be a Strongly Convex and therefore ' raise ValueError('loss function must be a Strongly Convex and therefore '
@ -104,15 +103,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method
optimizer = optimizers.get(optimizer) optimizer = optimizers.get(optimizer)
optimizer = Bolton(optimizer, loss) optimizer = Bolton(optimizer, loss)
super(BoltonModel, self).compile(optimizer, super(BoltonModel, self).compile(optimizer, loss=loss, **kwargs)
loss=loss,
metrics=metrics,
loss_weights=loss_weights,
sample_weight_mode=sample_weight_mode,
weighted_metrics=weighted_metrics,
target_tensors=target_tensors,
distribute=distribute,
**kwargs)
def fit(self, def fit(self,
x=None, x=None,

View file

@ -263,7 +263,12 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
def test_project(self, r, shape, n_out, init_value, result): def test_project(self, r, shape, n_out, init_value, result):
"""test that a fn of Bolton optimizer is working as expected. """test that a fn of Bolton optimizer is working as expected.
Missing args: Args:
r: Radius value for StrongConvex loss function.
shape: input_dimensionality
n_out: output dimensionality
init_value: the initial value for 'constant' kernel initializer
result: the expected output after projection.fFF
""" """
tf.random.set_seed(1) tf.random.set_seed(1)
@ -536,7 +541,8 @@ class SchedulerTest(keras_parameterized.TestCase):
""" test that attribute of internal optimizer is correctly rerouted to """ test that attribute of internal optimizer is correctly rerouted to
the internal optimizer the internal optimizer
Missing args Args:
err_msg: The expected error message from the scheduler bad call.
""" """
scheduler = opt.GammaBetaDecreasingStep() scheduler = opt.GammaBetaDecreasingStep()
with self.assertRaisesRegexp(Exception, err_msg): # pylint: disable=deprecated-method with self.assertRaisesRegexp(Exception, err_msg): # pylint: disable=deprecated-method
@ -559,7 +565,9 @@ class SchedulerTest(keras_parameterized.TestCase):
Test that attribute of internal optimizer is correctly rerouted to the Test that attribute of internal optimizer is correctly rerouted to the
internal optimizer internal optimizer
Missing Args: Args:
step: step number to 'GammaBetaDecreasingStep' 'Scheduler'.
res: expected result from call to 'GammaBetaDecreasingStep' 'Scheduler'.
""" """
beta = _ops.convert_to_tensor_v2(2, dtype=tf.float32) beta = _ops.convert_to_tensor_v2(2, dtype=tf.float32)
gamma = _ops.convert_to_tensor_v2(1, dtype=tf.float32) gamma = _ops.convert_to_tensor_v2(1, dtype=tf.float32)

View file

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tutorial for bolton module, the model and the optimizer.""" """Tutorial for bolton module, the model and the optimizer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow as tf # pylint: disable=wrong-import-position import tensorflow as tf # pylint: disable=wrong-import-position
from privacy.bolton import losses # pylint: disable=wrong-import-position from privacy.bolton import losses # pylint: disable=wrong-import-position