Remove dependence on six in clip_and_aggregate_gradients.py.
PiperOrigin-RevId: 481750014
parent d5538fccbb
commit 4aa531faa4

2 changed files with 5 additions and 14 deletions
|
@ -15,10 +15,7 @@ py_library(
|
||||||
"clip_and_aggregate_gradients.py",
|
"clip_and_aggregate_gradients.py",
|
||||||
],
|
],
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
deps = [
|
deps = ["//third_party/tensorflow/python/ops/parallel_for:control_flow_ops"],
|
||||||
"//third_party/py/six",
|
|
||||||
"//third_party/tensorflow/python/ops/parallel_for:control_flow_ops",
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
|
|
|
clip_and_aggregate_gradients.py:

@@ -16,10 +16,8 @@
 
 Modified from tape.jacobian to support sparse gradients.
 """
-import sys
 from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
 
-import six
 import tensorflow as tf
 
 from tensorflow.python.ops.parallel_for import control_flow_ops  # pylint: disable=g-direct-tensorflow-import
@@ -192,14 +190,10 @@ def clip_and_aggregate_gradients(
   try:
     output = control_flow_ops.pfor(loop_fn, target_size)
   except ValueError as err:
-    six.reraise(
-        ValueError,
-        ValueError(
-            str(err) + '\nEncountered an exception while vectorizing the '
-            'jacobian computation. Consider using a non-vectorized version, '
-            'i.e. by computing the gradient for each output sequentially.'),
-        sys.exc_info()[2])
+    raise ValueError(
+        'Encountered an exception while vectorizing the '
+        'batch_jacobian computation. Vectorization can be disabled by '
+        'setting experimental_use_pfor to False.') from err
 
   grads = []
   for i, out in enumerate(output):
     if out is not None:
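
For context, the removed six.reraise(tp, value, tb) call and the new raise ... from form are close but not identical: six.reraise re-raised a fresh ValueError with the original traceback reattached and str(err) folded into the message, while Python 3 exception chaining stores the original error in __cause__, so both the new message and the full original traceback appear in the error report without any sys.exc_info() plumbing. A minimal standalone sketch of the idiom (failing_vectorization is a hypothetical stand-in for the control_flow_ops.pfor call, not code from this repository):

    def failing_vectorization():
      # Hypothetical stand-in for control_flow_ops.pfor raising during tracing.
      raise ValueError('No pfor converter registered for some op')

    try:
      failing_vectorization()
    except ValueError as err:
      # Python 3 exception chaining: `from err` stores the original exception
      # in __cause__, so its traceback is printed below the new message.
      # This replaces six.reraise(ValueError, ValueError(str(err) + ...),
      # sys.exc_info()[2]), which was only needed for Python 2 compatibility.
      raise ValueError(
          'Encountered an exception while vectorizing the batch_jacobian '
          'computation. Vectorization can be disabled by setting '
          'experimental_use_pfor to False.') from err

Because the chained exception carries its own traceback, both import sys and import six become unnecessary, which is exactly what the import and BUILD dependency removals above reflect.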