Disable MLIR bridge for the test points that MLIR bridge silently fails
PiperOrigin-RevId: 676660290
This commit is contained in:
parent
e8856835a6
commit
d965556ebb
2 changed files with 2 additions and 0 deletions
|
@ -20,6 +20,7 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
|
||||||
class GradNormTpuTest(embedding_test.GradNormTest):
|
class GradNormTpuTest(embedding_test.GradNormTest):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
tf.config.experimental.disable_mlir_bridge()
|
||||||
super(embedding_test.GradNormTest, self).setUp()
|
super(embedding_test.GradNormTest, self).setUp()
|
||||||
self.strategy = common_test_utils.create_tpu_strategy()
|
self.strategy = common_test_utils.create_tpu_strategy()
|
||||||
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
|
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
|
||||||
|
|
|
@ -20,6 +20,7 @@ from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import
|
||||||
class GradNormTpuTest(nlp_on_device_embedding_test.GradNormTest):
|
class GradNormTpuTest(nlp_on_device_embedding_test.GradNormTest):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
tf.config.experimental.disable_mlir_bridge()
|
||||||
super(nlp_on_device_embedding_test.GradNormTest, self).setUp()
|
super(nlp_on_device_embedding_test.GradNormTest, self).setUp()
|
||||||
self.strategy = common_test_utils.create_tpu_strategy()
|
self.strategy = common_test_utils.create_tpu_strategy()
|
||||||
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
|
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
|
||||||
|
|
Loading…
Reference in a new issue