Skip to content

Commit

Permalink
Remove references to deprecated submodule jax.abstract_arrays
Browse files Browse the repository at this point in the history
Use jax.core instead (see jax-ml/jax#16271)

PiperOrigin-RevId: 538790761
Change-Id: I12ca976910a5dd383c1ec0778530c96475acbd7d
  • Loading branch information
Jake VanderPlas authored and copybara-github committed Jun 8, 2023
1 parent 95243a8 commit 84b2afc
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions saxml/server/jax/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ class SerializedPjitFunction:
# Serializable IR for the device computation.
ir: Any
# Abstract arrays for the flattend inputs.
flat_global_in_avals: Sequence[jax.abstract_arrays.ShapedArray]
flat_global_in_avals: Sequence[jax.core.ShapedArray]
# Abstract arrays for the flattend outputs.
flat_global_out_avals: Sequence[jax.abstract_arrays.ShapedArray]
flat_global_out_avals: Sequence[jax.core.ShapedArray]
# Whether the compilation uses a tuple to hold all args.
tuple_args: bool
# Shardings for the flattened inputs.
Expand Down

0 comments on commit 84b2afc

Please sign in to comment.