Skip to content

Commit

Permalink
[JAX] Replace uses of jnp.array in types with jnp.ndarray.
Browse files Browse the repository at this point in the history
`jnp.array` is a function, not a type:
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html
so it never makes sense to use `jnp.array` in a type annotation.

Presumably the intent was to write `jnp.ndarray` aka `jax.Array`. Change uses of `jnp.array` to `jnp.ndarray`.

PiperOrigin-RevId: 555366510
Change-Id: Ic9535b45be797b92d5dc2b7322c51800aa9a350f
  • Loading branch information
hawkinsp authored and copybara-github committed Aug 10, 2023
1 parent b4c99fa commit f61a18b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 12 deletions.
20 changes: 11 additions & 9 deletions lightweight_mmm/optimize_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,17 @@
jax.jit,
static_argnames=("media_mix_model", "media_input_shape", "target_scaler",
"media_scaler"))
def _objective_function(extra_features: jnp.ndarray,
media_mix_model: lightweight_mmm.LightweightMMM,
media_input_shape: Tuple[int,
int], media_gap: Optional[int],
target_scaler: Optional[preprocessing.CustomScaler],
media_scaler: preprocessing.CustomScaler,
geo_ratio: jnp.array,
seed: Optional[int],
media_values: jnp.ndarray) -> jnp.float64:
def _objective_function(
extra_features: jnp.ndarray,
media_mix_model: lightweight_mmm.LightweightMMM,
media_input_shape: Tuple[int, int],
media_gap: Optional[int],
target_scaler: Optional[preprocessing.CustomScaler],
media_scaler: preprocessing.CustomScaler,
geo_ratio: jnp.ndarray,
seed: Optional[int],
media_values: jnp.ndarray,
) -> jnp.float64:
"""Objective function to calculate the sum of all predictions of the model.
Args:
Expand Down
7 changes: 4 additions & 3 deletions lightweight_mmm/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,10 +668,11 @@ def _create_shaded_line_plot(predictions: jnp.ndarray,


def _call_fit_plotter(
predictions: jnp.array,
target: jnp.array,
predictions: jnp.ndarray,
target: jnp.ndarray,
interval_mid_range: float,
digits: int) -> matplotlib.figure.Figure:
digits: int,
) -> matplotlib.figure.Figure:
"""Calls the shaded line plot once for national and N times for geo models.
Args:
Expand Down

0 comments on commit f61a18b

Please sign in to comment.