diff --git a/tutorials/lm_dpsgd_tutorial.py b/tutorials/lm_dpsgd_tutorial.py index 93a8971..8f72ff6 100644 --- a/tutorials/lm_dpsgd_tutorial.py +++ b/tutorials/lm_dpsgd_tutorial.py @@ -139,8 +139,8 @@ def load_data(): test_dataset = tfds.load(name='lm1b/subwords8k', split=tfds.Split.TEST, batch_size=10000) - train_data = next(tfds.as_numpy(train_dataset)) - test_data = next(tfds.as_numpy(test_dataset)) + train_data = next(iter(tfds.as_numpy(train_dataset))) + test_data = next(iter(tfds.as_numpy(test_dataset))) train_data = train_data['text'].flatten() test_data = test_data['text'].flatten() else: