diff --git a/tutorials/lm_dpsgd_tutorial.py b/tutorials/lm_dpsgd_tutorial.py index 8bc1cff..67398ea 100644 --- a/tutorials/lm_dpsgd_tutorial.py +++ b/tutorials/lm_dpsgd_tutorial.py @@ -134,7 +134,8 @@ def load_data(): 'using a substitute dataset from the tensorflow_datasets module.') train_dataset = tfds.load(name='lm1b/subwords8k', split=tfds.Split.TRAIN, - batch_size=NB_TRAIN) + batch_size=NB_TRAIN, + shuffle_files=True) test_dataset = tfds.load(name='lm1b/subwords8k', split=tfds.Split.TEST, batch_size=10000)