forked from 626_privacy/tensorflow_privacy
Rename jax.experimental.optimizers -> jax.example_libraries.optimizers
Why? The former name has been deprecated since JAX version 0.2.25, released in November 2021 (see https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0225-nov-10-2021), and will soon be removed. PiperOrigin-RevId: 465670868
This commit is contained in:
parent
a9abfbc244
commit
6718ae2636
1 changed files with 3 additions and 2 deletions
|
@ -25,7 +25,7 @@ import jax
|
||||||
import collections
|
import collections
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
import jax.experimental.optimizers
|
import jax.example_libraries.optimizers
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
@ -109,7 +109,8 @@ def run():
|
||||||
lams = np.array(lambdas)
|
lams = np.array(lambdas)
|
||||||
|
|
||||||
# Use Adam, because thinking hard is overrated we have magic pixie dust.
|
# Use Adam, because thinking hard is overrated we have magic pixie dust.
|
||||||
init_1, opt_update_1, get_params_1 = jax.experimental.optimizers.adam(.01)
|
init_1, opt_update_1, get_params_1 = \
|
||||||
|
jax.example_libraries.optimizers.adam(.01)
|
||||||
@jax.jit
|
@jax.jit
|
||||||
def update_1(i, opt_state, gs):
|
def update_1(i, opt_state, gs):
|
||||||
return opt_update_1(i, gs, opt_state)
|
return opt_update_1(i, gs, opt_state)
|
||||||
|
|
Loading…
Reference in a new issue