Disable MLIR bridge for the test points that MLIR bridge silently fails

PiperOrigin-RevId: 676660290
This commit is contained in:
A. Unique TensorFlower 2024-09-19 19:51:02 -07:00
parent e8856835a6
commit d965556ebb
2 changed files with 2 additions and 0 deletions

View file

@ -20,6 +20,7 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
class GradNormTpuTest(embedding_test.GradNormTest): class GradNormTpuTest(embedding_test.GradNormTest):
def setUp(self): def setUp(self):
tf.config.experimental.disable_mlir_bridge()
super(embedding_test.GradNormTest, self).setUp() super(embedding_test.GradNormTest, self).setUp()
self.strategy = common_test_utils.create_tpu_strategy() self.strategy = common_test_utils.create_tpu_strategy()
self.assertIn('TPU', self.strategy.extended.worker_devices[0]) self.assertIn('TPU', self.strategy.extended.worker_devices[0])

View file

@ -20,6 +20,7 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
class GradNormTpuTest(nlp_on_device_embedding_test.GradNormTest): class GradNormTpuTest(nlp_on_device_embedding_test.GradNormTest):
def setUp(self): def setUp(self):
tf.config.experimental.disable_mlir_bridge()
super(nlp_on_device_embedding_test.GradNormTest, self).setUp() super(nlp_on_device_embedding_test.GradNormTest, self).setUp()
self.strategy = common_test_utils.create_tpu_strategy() self.strategy = common_test_utils.create_tpu_strategy()
self.assertIn('TPU', self.strategy.extended.worker_devices[0]) self.assertIn('TPU', self.strategy.extended.worker_devices[0])