Fixing missing args.

This commit is contained in:
Christopher Choquette Choo 2019-07-27 13:54:19 -04:00
parent 0317ce8077
commit 92f97ae32c
3 changed files with 28 additions and 27 deletions

View file

@ -76,18 +76,17 @@ class BoltonModel(Model): # pylint: disable=abstract-method
def compile(self,
optimizer,
loss,
metrics=None,
loss_weights=None,
sample_weight_mode=None,
weighted_metrics=None,
target_tensors=None,
distribute=None,
kernel_initializer=tf.initializers.GlorotUniform,
**kwargs): # pylint: disable=arguments-differ
"""See super class. Default optimizer used in Bolton method is SGD.
Missing args.
Args:
optimizer: The optimizer to use. This will be automatically wrapped
with the Bolton Optimizer.
loss: The loss function to use. Must be a StrongConvex loss (extend the
StrongConvexMixin).
kernel_initializer: The kernel initializer to use for the single layer.
kwargs: kwargs to keras Model.compile. See super.
"""
if not isinstance(loss, StrongConvexMixin):
raise ValueError('loss function must be a Strongly Convex and therefore '
@ -104,15 +103,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method
optimizer = optimizers.get(optimizer)
optimizer = Bolton(optimizer, loss)
super(BoltonModel, self).compile(optimizer,
loss=loss,
metrics=metrics,
loss_weights=loss_weights,
sample_weight_mode=sample_weight_mode,
weighted_metrics=weighted_metrics,
target_tensors=target_tensors,
distribute=distribute,
**kwargs)
super(BoltonModel, self).compile(optimizer, loss=loss, **kwargs)
def fit(self,
x=None,

View file

@ -263,7 +263,12 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
def test_project(self, r, shape, n_out, init_value, result):
"""test that a fn of Bolton optimizer is working as expected.
Missing args:
Args:
r: Radius value for StrongConvex loss function.
shape: input_dimensionality
n_out: output dimensionality
init_value: the initial value for 'constant' kernel initializer
result: the expected output after projection.fFF
"""
tf.random.set_seed(1)
@ -536,7 +541,8 @@ class SchedulerTest(keras_parameterized.TestCase):
""" test that attribute of internal optimizer is correctly rerouted to
the internal optimizer
Missing args
Args:
err_msg: The expected error message from the scheduler bad call.
"""
scheduler = opt.GammaBetaDecreasingStep()
with self.assertRaisesRegexp(Exception, err_msg): # pylint: disable=deprecated-method
@ -559,7 +565,9 @@ class SchedulerTest(keras_parameterized.TestCase):
Test that attribute of internal optimizer is correctly rerouted to the
internal optimizer
Missing Args:
Args:
step: step number to 'GammaBetaDecreasingStep' 'Scheduler'.
res: expected result from call to 'GammaBetaDecreasingStep' 'Scheduler'.
"""
beta = _ops.convert_to_tensor_v2(2, dtype=tf.float32)
gamma = _ops.convert_to_tensor_v2(1, dtype=tf.float32)

View file

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tutorial for bolton module, the model and the optimizer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=wrong-import-position
from privacy.bolton import losses # pylint: disable=wrong-import-position