Remove dependence on six in clip_and_aggregate_gradients.py.
PiperOrigin-RevId: 481750014
This commit is contained in:
parent
d5538fccbb
commit
4aa531faa4
2 changed files with 5 additions and 14 deletions
|
@ -15,10 +15,7 @@ py_library(
|
|||
"clip_and_aggregate_gradients.py",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
"//third_party/py/six",
|
||||
"//third_party/tensorflow/python/ops/parallel_for:control_flow_ops",
|
||||
],
|
||||
deps = ["//third_party/tensorflow/python/ops/parallel_for:control_flow_ops"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
|
|
@ -16,10 +16,8 @@
|
|||
|
||||
Modified from tape.jacobian to support sparse gradients.
|
||||
"""
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
|
||||
|
||||
import six
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.ops.parallel_for import control_flow_ops # pylint: disable=g-direct-tensorflow-import
|
||||
|
@ -192,14 +190,10 @@ def clip_and_aggregate_gradients(
|
|||
try:
|
||||
output = control_flow_ops.pfor(loop_fn, target_size)
|
||||
except ValueError as err:
|
||||
six.reraise(
|
||||
ValueError,
|
||||
ValueError(
|
||||
str(err) + '\nEncountered an exception while vectorizing the '
|
||||
'jacobian computation. Consider using a non-vectorized version, '
|
||||
'i.e. by computing the gradient for each output sequentially.'),
|
||||
sys.exc_info()[2])
|
||||
|
||||
raise ValueError(
|
||||
'Encountered an exception while vectorizing the '
|
||||
'batch_jacobian computation. Vectorization can be disabled by '
|
||||
'setting experimental_use_pfor to False.') from err
|
||||
grads = []
|
||||
for i, out in enumerate(output):
|
||||
if out is not None:
|
||||
|
|
Loading…
Reference in a new issue