In dp_optimizer_keras_sparse, update `iterations` to reflect the number of logical batches rather than physical batches.
Before this change, when using gradient accumulation, the `iterations` variable was incremented at every physical batch, while the model variables were only updated at every logical batch (one logical batch = `gradient_accumulation_steps` physical batches). This caused optimizers that explicitly depend on `iterations` (such as Adam) to behave very differently under gradient accumulation. With this change, `iterations` is only incremented after each logical batch.

PiperOrigin-RevId: 517197044
This commit is contained in:
parent
7ae50c5ca5
commit
52806ba952
4 changed files with 247 additions and 20 deletions
|
@ -109,6 +109,15 @@ py_test(
|
||||||
deps = [":dp_optimizer_keras_sparse"],
|
deps = [":dp_optimizer_keras_sparse"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "dp_optimizer_keras_sparse_distributed_test",
|
||||||
|
timeout = "long",
|
||||||
|
srcs = ["dp_optimizer_keras_sparse_distributed_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [":dp_optimizer_keras_sparse"],
|
||||||
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "dp_optimizer_vectorized_test",
|
name = "dp_optimizer_vectorized_test",
|
||||||
timeout = "long",
|
timeout = "long",
|
||||||
|
|
|
@ -185,6 +185,7 @@ def make_sparse_keras_optimizer_class(cls):
|
||||||
self._num_microbatches = num_microbatches
|
self._num_microbatches = num_microbatches
|
||||||
self._was_dp_gradients_called = False
|
self._was_dp_gradients_called = False
|
||||||
self._noise_stddev = None
|
self._noise_stddev = None
|
||||||
|
self._acc_iterations = None
|
||||||
if self._num_microbatches is not None:
|
if self._num_microbatches is not None:
|
||||||
# The loss/gradients is the mean over the microbatches so we
|
# The loss/gradients is the mean over the microbatches so we
|
||||||
# divide the noise by num_microbatches too to obtain the correct
|
# divide the noise by num_microbatches too to obtain the correct
|
||||||
|
@ -202,15 +203,29 @@ def make_sparse_keras_optimizer_class(cls):
|
||||||
def _create_slots(self, var_list):
|
def _create_slots(self, var_list):
|
||||||
super()._create_slots(var_list) # pytype: disable=attribute-error
|
super()._create_slots(var_list) # pytype: disable=attribute-error
|
||||||
if self.gradient_accumulation_steps > 1:
|
if self.gradient_accumulation_steps > 1:
|
||||||
|
# Slots for accumulating gradients.
|
||||||
for var in var_list:
|
for var in var_list:
|
||||||
self.add_slot(var, 'grad_acc')
|
self.add_slot(var, 'grad_acc')
|
||||||
|
if self._acc_iterations is None:
|
||||||
|
# Variable for the iterations, used for bookkeeping when to accumulate
|
||||||
|
# and when to update.
|
||||||
|
self._acc_iterations = self.add_weight(
|
||||||
|
'acc_iterations',
|
||||||
|
shape=[],
|
||||||
|
trainable=False,
|
||||||
|
dtype=tf.int64,
|
||||||
|
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
|
||||||
|
)
|
||||||
|
|
||||||
def _prepare_local(self, var_device, var_dtype, apply_state):
|
def _prepare_local(self, var_device, var_dtype, apply_state):
|
||||||
super()._prepare_local(var_device, var_dtype, apply_state) # pytype: disable=attribute-error
|
super()._prepare_local(var_device, var_dtype, apply_state) # pytype: disable=attribute-error
|
||||||
if self.gradient_accumulation_steps > 1:
|
if self.gradient_accumulation_steps > 1:
|
||||||
apply_update = tf.math.equal(
|
apply_update = tf.math.equal(
|
||||||
tf.math.floormod(self.iterations + 1,
|
tf.math.floormod(
|
||||||
self.gradient_accumulation_steps), 0)
|
self._acc_iterations + 1, self.gradient_accumulation_steps
|
||||||
|
),
|
||||||
|
0,
|
||||||
|
)
|
||||||
grad_scaler = tf.cast(1. / self.gradient_accumulation_steps, var_dtype)
|
grad_scaler = tf.cast(1. / self.gradient_accumulation_steps, var_dtype)
|
||||||
apply_state[(var_device, var_dtype)].update({
|
apply_state[(var_device, var_dtype)].update({
|
||||||
'apply_update': apply_update,
|
'apply_update': apply_update,
|
||||||
|
@ -218,7 +233,7 @@ def make_sparse_keras_optimizer_class(cls):
|
||||||
})
|
})
|
||||||
|
|
||||||
def _resource_apply(self, accum_op, grad, var, apply_state=None):
|
def _resource_apply(self, accum_op, grad, var, apply_state=None):
|
||||||
"""Help method for _resource_apply_dense and _resource_apply_sparse."""
|
"""Helper method for _resource_apply_dense and _resource_apply_sparse."""
|
||||||
if self.gradient_accumulation_steps > 1:
|
if self.gradient_accumulation_steps > 1:
|
||||||
var_device, var_dtype = var.device, var.dtype.base_dtype
|
var_device, var_dtype = var.device, var.dtype.base_dtype
|
||||||
coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
|
coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
|
||||||
|
@ -235,9 +250,8 @@ def make_sparse_keras_optimizer_class(cls):
|
||||||
tf.zeros_like(grad_acc),
|
tf.zeros_like(grad_acc),
|
||||||
use_locking=self._use_locking,
|
use_locking=self._use_locking,
|
||||||
read_value=False)
|
read_value=False)
|
||||||
accum_op(grad_acc, grad, use_locking=self._use_locking)
|
with tf.control_dependencies([accum_op(grad_acc, grad)]):
|
||||||
return tf.cond(
|
return tf.cond(coefficients['apply_update'], _update_grad, tf.no_op)
|
||||||
coefficients['apply_update'], _update_grad, lambda: tf.no_op()) # pylint: disable=unnecessary-lambda
|
|
||||||
else:
|
else:
|
||||||
grad = tf.convert_to_tensor(grad)
|
grad = tf.convert_to_tensor(grad)
|
||||||
grad = grad + self._generate_noise(grad)
|
grad = grad + self._generate_noise(grad)
|
||||||
|
@ -246,9 +260,11 @@ def make_sparse_keras_optimizer_class(cls):
|
||||||
|
|
||||||
def _resource_apply_dense(self, grad, var, apply_state=None):
|
def _resource_apply_dense(self, grad, var, apply_state=None):
|
||||||
"""Handles dense gradients."""
|
"""Handles dense gradients."""
|
||||||
def _accum_op(grad_acc, grad, use_locking):
|
def _accum_op(grad_acc, grad):
|
||||||
return grad_acc.assign_add(
|
return grad_acc.assign_add(
|
||||||
grad, use_locking=use_locking, read_value=False)
|
grad, use_locking=self._use_locking, read_value=False
|
||||||
|
)
|
||||||
|
|
||||||
return self._resource_apply(_accum_op, grad, var, apply_state)
|
return self._resource_apply(_accum_op, grad, var, apply_state)
|
||||||
|
|
||||||
# This method is implemented the same as that in optimizer_v2.py. We
|
# This method is implemented the same as that in optimizer_v2.py. We
|
||||||
|
@ -271,13 +287,49 @@ def make_sparse_keras_optimizer_class(cls):
|
||||||
|
|
||||||
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
|
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
|
||||||
"""Handles deduped sparse gradients."""
|
"""Handles deduped sparse gradients."""
|
||||||
def _accum_op(grad_acc, sparse_delta, use_locking):
|
def _accum_op(grad_acc, sparse_delta):
|
||||||
return grad_acc.scatter_add(
|
return grad_acc.scatter_add(
|
||||||
sparse_delta=sparse_delta, use_locking=use_locking)
|
sparse_delta=sparse_delta, use_locking=self._use_locking
|
||||||
|
)
|
||||||
|
|
||||||
sparse_delta = tf.IndexedSlices(
|
sparse_delta = tf.IndexedSlices(
|
||||||
values=grad, indices=indices, dense_shape=var.shape)
|
values=grad, indices=indices, dense_shape=var.shape)
|
||||||
return self._resource_apply(_accum_op, sparse_delta, var, apply_state)
|
return self._resource_apply(_accum_op, sparse_delta, var, apply_state)
|
||||||
|
|
||||||
|
def _distributed_apply(
|
||||||
|
self, distribution, grads_and_vars, apply_state, name
|
||||||
|
):
|
||||||
|
apply_op = super()._distributed_apply(
|
||||||
|
distribution, grads_and_vars, apply_state, name
|
||||||
|
)
|
||||||
|
if self.gradient_accumulation_steps > 1:
|
||||||
|
# The original _distributed_apply increments self.iterations after each
|
||||||
|
# call. But we want to increment it only after each logical batch is
|
||||||
|
# processed, so optimizers that explicitly use self.iterations in their
|
||||||
|
# updates (such as Adam) can use the correct value.
|
||||||
|
def increment_acc_iterations():
|
||||||
|
# Always use locking when updating the steps, so we don't under-count
|
||||||
|
# the steps (which could invalidate privacy accounting).
|
||||||
|
return self._acc_iterations.assign_add(
|
||||||
|
1, use_locking=True, read_value=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def assign_iterations():
|
||||||
|
return self.iterations.assign(
|
||||||
|
tf.math.floordiv(
|
||||||
|
self._acc_iterations, self.gradient_accumulation_steps
|
||||||
|
),
|
||||||
|
use_locking=True,
|
||||||
|
read_value=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with tf.control_dependencies([apply_op]):
|
||||||
|
with tf.control_dependencies([increment_acc_iterations()]):
|
||||||
|
return assign_iterations()
|
||||||
|
else:
|
||||||
|
# No accumulation.
|
||||||
|
return apply_op
|
||||||
|
|
||||||
def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
|
def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
|
||||||
"""DP-SGD version of base class method."""
|
"""DP-SGD version of base class method."""
|
||||||
self._was_dp_gradients_called = True
|
self._was_dp_gradients_called = True
|
||||||
|
|
|
@ -0,0 +1,153 @@
|
||||||
|
# Copyright 2023, The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tests DPSparseKerasSGDOptimizer in distributed training."""
|
||||||
|
import contextlib
|
||||||
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras_sparse
|
||||||
|
|
||||||
|
ds_combinations = tf.__internal__.distribute.combinations
|
||||||
|
|
||||||
|
|
||||||
|
STRATEGIES = [
|
||||||
|
ds_combinations.one_device_strategy,
|
||||||
|
ds_combinations.parameter_server_strategy_1worker_2ps_cpu,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedTrainingTest(parameterized.TestCase, tf.test.TestCase):
|
||||||
|
|
||||||
|
@ds_combinations.generate(
|
||||||
|
tf.__internal__.test.combinations.combine(
|
||||||
|
strategy=STRATEGIES, mode="eager"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def test_training_works(self, strategy):
|
||||||
|
if isinstance(strategy, tf.distribute.OneDeviceStrategy):
|
||||||
|
strategy_scope = contextlib.nullcontext()
|
||||||
|
else:
|
||||||
|
strategy_scope = strategy.scope()
|
||||||
|
|
||||||
|
def make_model():
|
||||||
|
inputs = tf.keras.Input((1000,))
|
||||||
|
dense = tf.keras.layers.Dense(
|
||||||
|
units=1, use_bias=False, kernel_initializer=tf.initializers.ones()
|
||||||
|
)
|
||||||
|
outputs = dense(inputs)
|
||||||
|
return tf.keras.models.Model(inputs=inputs, outputs=outputs)
|
||||||
|
|
||||||
|
x = tf.ones(shape=[5000, 1000])
|
||||||
|
y = tf.zeros(shape=[5000])
|
||||||
|
with strategy_scope:
|
||||||
|
model = make_model()
|
||||||
|
clip = 100.0
|
||||||
|
noise_mult = 0.01
|
||||||
|
acc_steps = 5
|
||||||
|
batch_size = 10
|
||||||
|
opt = dp_optimizer_keras_sparse.DPSparseKerasSGDOptimizer(
|
||||||
|
l2_norm_clip=clip,
|
||||||
|
noise_multiplier=noise_mult,
|
||||||
|
gradient_accumulation_steps=acc_steps,
|
||||||
|
learning_rate=0.001,
|
||||||
|
)
|
||||||
|
model.compile(
|
||||||
|
loss=tf.keras.losses.MeanAbsoluteError(
|
||||||
|
reduction=tf.keras.losses.Reduction.NONE
|
||||||
|
),
|
||||||
|
optimizer=opt,
|
||||||
|
)
|
||||||
|
history = model.fit(
|
||||||
|
x=x,
|
||||||
|
y=y,
|
||||||
|
epochs=2,
|
||||||
|
steps_per_epoch=500,
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
self.assertIn("loss", history.history)
|
||||||
|
# total steps: 1000 (2 epochs, 500 steps/epoch)
|
||||||
|
# accumulation steps: 5
|
||||||
|
# expected_iterations = total steps / accumulation steps
|
||||||
|
expected_iterations = 1000 / acc_steps # = 200
|
||||||
|
# The loss is |w.x - y| (where w is the dense layer weights).
|
||||||
|
# The gradient is sign(w.x - y)x. With the choice of x, y, the gradient
|
||||||
|
# becomes x.
|
||||||
|
# So each gradient update is w <- w - learning_rate*1 + noise
|
||||||
|
expected_params = 1 - 0.001 * expected_iterations
|
||||||
|
expected_noise = (
|
||||||
|
0.001
|
||||||
|
* clip
|
||||||
|
* noise_mult
|
||||||
|
* np.sqrt(expected_iterations)
|
||||||
|
/ (acc_steps * batch_size)
|
||||||
|
)
|
||||||
|
self.assertEqual(opt.iterations.numpy(), expected_iterations)
|
||||||
|
self.assertAllClose(
|
||||||
|
np.mean(model.trainable_variables[0].numpy()),
|
||||||
|
expected_params, # 0.8
|
||||||
|
# stddev = expected_noise/√1000 (since we're averaging 1000 samples)
|
||||||
|
# we set atol to 4 stddev
|
||||||
|
atol=4 * expected_noise / np.sqrt(1000), # ≈ 3.58e-5
|
||||||
|
)
|
||||||
|
self.assertAllClose(
|
||||||
|
np.std(model.trainable_variables[0].numpy()),
|
||||||
|
expected_noise, # ≈ 2.83e-4
|
||||||
|
atol=4 * expected_noise / np.sqrt(1000), # ≈ 3.58e-5
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _set_spawn_exe_path():
|
||||||
|
"""Set the path to the executable for spawned processes.
|
||||||
|
|
||||||
|
This utility searches for the binary the parent process is using, and sets
|
||||||
|
the executable of multiprocessing's context accordingly.
|
||||||
|
It is branched from tensorflow/python/distribute/multi_process_lib.py, the
|
||||||
|
only change being that it additionally looks under "org_tensorflow_privacy".
|
||||||
|
"""
|
||||||
|
if sys.argv[0].endswith(".py"):
|
||||||
|
|
||||||
|
def guess_path(package_root):
|
||||||
|
# If all we have is a python module path, we'll need to make a guess for
|
||||||
|
# the actual executable path.
|
||||||
|
if "bazel-out" in sys.argv[0] and package_root in sys.argv[0]:
|
||||||
|
# Guess the binary path under bazel. For target
|
||||||
|
# //tensorflow/python/distribute:input_lib_test_multiworker_gpu, the
|
||||||
|
# argv[0] is in the form of
|
||||||
|
# /.../tensorflow/python/distribute/input_lib_test.py
|
||||||
|
# and the binary is
|
||||||
|
# /.../tensorflow/python/distribute/input_lib_test_multiworker_gpu
|
||||||
|
package_root_base = sys.argv[0][: sys.argv[0].rfind(package_root)]
|
||||||
|
binary = os.environ["TEST_TARGET"][2:].replace(":", "/", 1)
|
||||||
|
possible_path = os.path.join(package_root_base, package_root, binary)
|
||||||
|
if os.access(possible_path, os.X_OK):
|
||||||
|
return possible_path
|
||||||
|
return None
|
||||||
|
|
||||||
|
path = (
|
||||||
|
guess_path("org_tensorflow")
|
||||||
|
or guess_path("org_keras")
|
||||||
|
or guess_path("org_tensorflow_privacy")
|
||||||
|
)
|
||||||
|
if path is not None:
|
||||||
|
sys.argv[0] = path
|
||||||
|
multiprocessing.get_context().set_executable(sys.argv[0])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
_set_spawn_exe_path()
|
||||||
|
tf.__internal__.distribute.multi_process_runner.test_main()
|
|
@ -338,21 +338,25 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
# After first call to optimizer values didn't change
|
# After first call to optimizer values didn't change
|
||||||
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0)
|
self.assertAllCloseAccordingToType([[1.0, 2.0]], var0)
|
||||||
self.assertAllCloseAccordingToType([3.0], var1)
|
self.assertAllCloseAccordingToType([3.0], var1)
|
||||||
|
self.assertEqual(opt.iterations, 0)
|
||||||
|
|
||||||
opt.minimize(loss2, [var0, var1])
|
opt.minimize(loss2, [var0, var1])
|
||||||
# After second call to optimizer updates were applied
|
# After second call to optimizer updates were applied
|
||||||
self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0)
|
self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0)
|
||||||
self.assertAllCloseAccordingToType([2.0], var1)
|
self.assertAllCloseAccordingToType([2.0], var1)
|
||||||
|
self.assertEqual(opt.iterations, 1)
|
||||||
|
|
||||||
opt.minimize(loss2, [var0, var1])
|
opt.minimize(loss2, [var0, var1])
|
||||||
# After third call to optimizer values didn't change
|
# After third call to optimizer values didn't change
|
||||||
self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0)
|
self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0)
|
||||||
self.assertAllCloseAccordingToType([2.0], var1)
|
self.assertAllCloseAccordingToType([2.0], var1)
|
||||||
|
self.assertEqual(opt.iterations, 1)
|
||||||
|
|
||||||
opt.minimize(loss2, [var0, var1])
|
opt.minimize(loss2, [var0, var1])
|
||||||
# After fourth call to optimizer updates were applied again
|
# After fourth call to optimizer updates were applied again
|
||||||
self.assertAllCloseAccordingToType([[-4.0, -0.5]], var0)
|
self.assertAllCloseAccordingToType([[-4.0, -0.5]], var0)
|
||||||
self.assertAllCloseAccordingToType([1.0], var1)
|
self.assertAllCloseAccordingToType([1.0], var1)
|
||||||
|
self.assertEqual(opt.iterations, 2)
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
('DPSparseKerasSGDOptimizer 1',
|
('DPSparseKerasSGDOptimizer 1',
|
||||||
|
@ -388,6 +392,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
self.assertNotAllClose([[1.0, 2.0]], var0)
|
self.assertNotAllClose([[1.0, 2.0]], var0)
|
||||||
self.assertNotAllClose([3.0], var1)
|
self.assertNotAllClose([3.0], var1)
|
||||||
|
self.assertEqual(opt.iterations, 1)
|
||||||
|
|
||||||
def testKerasModelBaselineSaving(self):
|
def testKerasModelBaselineSaving(self):
|
||||||
"""Tests that DP optimizers work with tf.keras.Model."""
|
"""Tests that DP optimizers work with tf.keras.Model."""
|
||||||
|
@ -455,10 +460,15 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
model.fit(train_data, train_labels, batch_size=8, epochs=1, shuffle=False)
|
model.fit(train_data, train_labels, batch_size=8, epochs=1, shuffle=False)
|
||||||
|
|
||||||
@parameterized.named_parameters(('1', 1), ('None', None))
|
@parameterized.named_parameters(
|
||||||
def testKerasModelBaselineNoNoise(self, num_microbatches):
|
('no_microbatch_no_accumulation', False, False),
|
||||||
|
('no_microbatch_accumulation', False, True),
|
||||||
|
('microbatch_no_accumulation', True, False),
|
||||||
|
('microbatch_accumulation', True, True),
|
||||||
|
)
|
||||||
|
def testKerasModelBaselineNoNoise(self, microbatch, accumulate):
|
||||||
"""Tests that DP optimizers work with tf.keras.Model."""
|
"""Tests that DP optimizers work with tf.keras.Model."""
|
||||||
|
acc_steps = 2 if accumulate else 1
|
||||||
model = tf.keras.models.Sequential(layers=[
|
model = tf.keras.models.Sequential(layers=[
|
||||||
tf.keras.layers.Dense(
|
tf.keras.layers.Dense(
|
||||||
1,
|
1,
|
||||||
|
@ -471,22 +481,25 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
optimizer = dp_optimizer.DPSparseKerasSGDOptimizer(
|
optimizer = dp_optimizer.DPSparseKerasSGDOptimizer(
|
||||||
l2_norm_clip=100.0,
|
l2_norm_clip=100.0,
|
||||||
noise_multiplier=0.0,
|
noise_multiplier=0.0,
|
||||||
num_microbatches=num_microbatches,
|
num_microbatches=None if microbatch else 1,
|
||||||
learning_rate=0.05)
|
learning_rate=0.05,
|
||||||
|
gradient_accumulation_steps=acc_steps,
|
||||||
|
)
|
||||||
loss = tf.keras.losses.MeanSquaredError(reduction='none')
|
loss = tf.keras.losses.MeanSquaredError(reduction='none')
|
||||||
model.compile(optimizer, loss)
|
model.compile(optimizer, loss)
|
||||||
|
|
||||||
true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32)
|
true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32)
|
||||||
true_bias = np.array([6.0]).astype(np.float32)
|
true_bias = np.array([6.0]).astype(np.float32)
|
||||||
train_data = np.random.normal(scale=3.0, size=(1000, 4)).astype(np.float32)
|
train_data = np.random.normal(scale=3.0, size=(2000, 4)).astype(np.float32)
|
||||||
train_labels = np.matmul(train_data,
|
train_labels = np.matmul(train_data, true_weights) + true_bias
|
||||||
true_weights) + true_bias + np.random.normal(
|
|
||||||
scale=0.0, size=(1000, 1)).astype(np.float32)
|
|
||||||
|
|
||||||
model.fit(train_data, train_labels, batch_size=8, epochs=1, shuffle=False)
|
model.fit(train_data, train_labels, batch_size=10, epochs=1, shuffle=False)
|
||||||
|
|
||||||
self.assertAllClose(model.get_weights()[0], true_weights, atol=0.05)
|
self.assertAllClose(model.get_weights()[0], true_weights, atol=0.05)
|
||||||
self.assertAllClose(model.get_weights()[1], true_bias, atol=0.05)
|
self.assertAllClose(model.get_weights()[1], true_bias, atol=0.05)
|
||||||
|
# Check that the optimizer's iterations equal the number of logical batches.
|
||||||
|
total_batches = 200
|
||||||
|
self.assertEqual(optimizer.iterations.numpy(), total_batches / acc_steps)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in a new issue