Rename jax.experimental.optimizers -> jax.example_libraries.optimizers
Why? The former name has been deprecated since JAX version 0.2.25, released in November 2021 (see https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0225-nov-10-2021), and will soon be removed. PiperOrigin-RevId: 465670868
This commit is contained in:
parent
a9abfbc244
commit
6718ae2636
1 changed files with 3 additions and 2 deletions
|
@ -25,7 +25,7 @@ import jax
|
|||
import collections
|
||||
from PIL import Image
|
||||
|
||||
import jax.experimental.optimizers
|
||||
import jax.example_libraries.optimizers
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
@ -109,7 +109,8 @@ def run():
|
|||
lams = np.array(lambdas)
|
||||
|
||||
# Use Adam, because thinking hard is overrated we have magic pixie dust.
|
||||
init_1, opt_update_1, get_params_1 = jax.experimental.optimizers.adam(.01)
|
||||
init_1, opt_update_1, get_params_1 = \
|
||||
jax.example_libraries.optimizers.adam(.01)
|
||||
@jax.jit
|
||||
def update_1(i, opt_state, gs):
|
||||
return opt_update_1(i, gs, opt_state)
|
||||
|
|
Loading…
Reference in a new issue