From d965556ebb67bd62626830339478e9ebab7ab9bd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Sep 2024 19:51:02 -0700 Subject: [PATCH] Disable MLIR bridge for the test points that MLIR bridge silently fails PiperOrigin-RevId: 676660290 --- .../registry_functions/embedding_tpu_test.py | 1 + .../registry_functions/nlp_on_device_embedding_tpu_test.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_tpu_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_tpu_test.py index 9411d61..a15d5c5 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_tpu_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_tpu_test.py @@ -20,6 +20,7 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import class GradNormTpuTest(embedding_test.GradNormTest): def setUp(self): + tf.config.experimental.disable_mlir_bridge() super(embedding_test.GradNormTest, self).setUp() self.strategy = common_test_utils.create_tpu_strategy() self.assertIn('TPU', self.strategy.extended.worker_devices[0]) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_tpu_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_tpu_test.py index 283a671..98acc87 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_tpu_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_tpu_test.py @@ -20,6 +20,7 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import class GradNormTpuTest(nlp_on_device_embedding_test.GradNormTest): def setUp(self): + tf.config.experimental.disable_mlir_bridge() super(nlp_on_device_embedding_test.GradNormTest, self).setUp() self.strategy = common_test_utils.create_tpu_strategy() self.assertIn('TPU', self.strategy.extended.worker_devices[0])