Remove dependence on six in clip_and_aggregate_gradients.py.

PiperOrigin-RevId: 481750014
This commit is contained in:
Steve Chien 2022-10-17 15:06:56 -07:00 committed by A. Unique TensorFlower
parent d5538fccbb
commit 4aa531faa4
2 changed files with 5 additions and 14 deletions

View file

@ -15,10 +15,7 @@ py_library(
"clip_and_aggregate_gradients.py", "clip_and_aggregate_gradients.py",
], ],
srcs_version = "PY3", srcs_version = "PY3",
deps = [ deps = ["//third_party/tensorflow/python/ops/parallel_for:control_flow_ops"],
"//third_party/py/six",
"//third_party/tensorflow/python/ops/parallel_for:control_flow_ops",
],
) )
py_library( py_library(

View file

@ -16,10 +16,8 @@
Modified from tape.jacobian to support sparse gradients. Modified from tape.jacobian to support sparse gradients.
""" """
import sys
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
import six
import tensorflow as tf import tensorflow as tf
from tensorflow.python.ops.parallel_for import control_flow_ops # pylint: disable=g-direct-tensorflow-import from tensorflow.python.ops.parallel_for import control_flow_ops # pylint: disable=g-direct-tensorflow-import
@ -192,14 +190,10 @@ def clip_and_aggregate_gradients(
try: try:
output = control_flow_ops.pfor(loop_fn, target_size) output = control_flow_ops.pfor(loop_fn, target_size)
except ValueError as err: except ValueError as err:
six.reraise( raise ValueError(
ValueError, 'Encountered an exception while vectorizing the '
ValueError( 'batch_jacobian computation. Vectorization can be disabled by '
str(err) + '\nEncountered an exception while vectorizing the ' 'setting experimental_use_pfor to False.') from err
'jacobian computation. Consider using a non-vectorized version, '
'i.e. by computing the gradient for each output sequentially.'),
sys.exc_info()[2])
grads = [] grads = []
for i, out in enumerate(output): for i, out in enumerate(output):
if out is not None: if out is not None: