diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_tpu_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_tpu_test.py index 9411d61..a15d5c5 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_tpu_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_tpu_test.py @@ -20,6 +20,7 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import class GradNormTpuTest(embedding_test.GradNormTest): def setUp(self): + tf.config.experimental.disable_mlir_bridge() super(embedding_test.GradNormTest, self).setUp() self.strategy = common_test_utils.create_tpu_strategy() self.assertIn('TPU', self.strategy.extended.worker_devices[0]) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_tpu_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_tpu_test.py index 283a671..98acc87 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_tpu_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_tpu_test.py @@ -20,6 +20,7 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import class GradNormTpuTest(nlp_on_device_embedding_test.GradNormTest): def setUp(self): + tf.config.experimental.disable_mlir_bridge() super(nlp_on_device_embedding_test.GradNormTest, self).setUp() self.strategy = common_test_utils.create_tpu_strategy() self.assertIn('TPU', self.strategy.extended.worker_devices[0])