JAX bi-tempered bug fix
- fix shape mismatch in custom gradient when len(activations.shape) > 2
- remove static_argnums
eamid authored Dec 22, 2021
1 parent 2d8cc60 commit 1c65c77
Showing 1 changed file with 9 additions and 7 deletions.
jax/loss.py: 16 changes (9 additions, 7 deletions)
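The first bullet in the commit message is the user-visible part of the fix: when activations has more than two dimensions (for example a batch of sequences of logits), jax.grad of the loss used to fail with a shape mismatch inside the custom backward pass. A minimal sketch of that scenario, assuming the public signature bi_tempered_logistic_loss(activations, labels, t1, t2, ...) with one-hot labels; the import path and all sizes here are illustrative, not taken from this commit:

import jax
import jax.numpy as jnp
from loss import bi_tempered_logistic_loss  # module path assumed

# Rank-3 activations: (batch, sequence, num_classes).
activations = jnp.zeros((2, 3, 5))
labels = jax.nn.one_hot(jnp.zeros((2, 3), dtype=jnp.int32), 5)

def total_loss(acts):
  # Per-example losses have shape (2, 3); summing gives a scalar for grad.
  return jnp.sum(bi_tempered_logistic_loss(acts, labels, t1=0.8, t2=1.2))

# Before this fix, the backward pass received a (2, 3) cotangent but built a
# (2, 3, 5) derivative, so the final element-wise multiply raised an error.
grads = jax.grad(total_loss)(activations)
print(grads.shape)  # (2, 3, 5)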
@@ -58,7 +58,7 @@ def _internal_exp_t(u: jnp.ndarray, t: float) -> jnp.ndarray:
functools.partial(_internal_exp_t, t=t), u)


-@functools.partial(jax.jit, static_argnums=2)
+@jax.jit
def compute_normalization_fixed_point(activations: jnp.ndarray,
t: float,
num_iters: int = 5):
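For context on the second bullet of the commit message, the hunk above (and the six identical decorator changes below) swaps functools.partial(jax.jit, static_argnums=2) for a plain @jax.jit. A minimal contrast between the two decorations; f_static and f_traced are hypothetical names, not functions from this file:

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnums=2)
def f_static(x, t, num_iters=5):
  # num_iters is a compile-time Python constant; each new value recompiles.
  return x * t + num_iters

@jax.jit
def f_traced(x, t, num_iters=5):
  # num_iters is traced like any other input; changing it does not recompile.
  return x * t + num_iters

Dropping static_argnums is only safe when the argument is never needed as a concrete Python value inside the function (for example to drive a Python-level loop), which is presumably the case for the num_iters arguments in this file.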
@@ -93,7 +93,7 @@ def body_fun(carry):
return -log_t(1.0 / logt_partition, t) + mu


-@functools.partial(jax.jit, static_argnums=2)
+@jax.jit
def compute_normalization_binary_search(activations: jnp.ndarray,
t: float,
num_iters: int = 10):
@@ -141,7 +141,7 @@ def body_fun(carry):
return logt_partition + mu


-@functools.partial(jax.jit, static_argnums=2)
+@jax.jit
def compute_tempered_normalization(activations: jnp.ndarray,
t: float,
num_iters: int = 5):
@@ -154,7 +154,7 @@ def compute_tempered_normalization(activations: jnp.ndarray,
activations)


-@functools.partial(jax.jit, static_argnums=2)
+@jax.jit
def compute_normalization(activations: jnp.ndarray,
t: float,
num_iters: int = 5):
@@ -174,7 +174,7 @@ def compute_normalization(activations: jnp.ndarray,
activations)


-@functools.partial(jax.jit, static_argnums=2)
+@jax.jit
def tempered_sigmoid(activations, t, num_iters=5):
"""Tempered sigmoid function.
@@ -195,7 +195,7 @@ def tempered_sigmoid(activations, t, num_iters=5):
return jnp.reshape(one_class_probabilities, input_shape)


-@functools.partial(jax.jit, static_argnums=2)
+@jax.jit
def tempered_softmax(activations, t, num_iters=5):
"""Tempered softmax function.
@@ -280,7 +280,7 @@ def bi_tempered_logistic_loss(activations,
return loss_values


-@functools.partial(jax.jit, static_argnums=5)
+@jax.jit
def bi_tempered_logistic_loss_fwd(activations,
labels,
t1,
@@ -351,6 +351,8 @@ def bi_tempered_logistic_loss_bwd(res, d_loss):
escorts = escorts / jnp.sum(escorts, -1, keepdims=True)
derivative = delta_probs_times_forget_factor - jnp.multiply(
    escorts, delta_forget_sum)
+if len(d_loss.shape) < len(derivative.shape):
+  d_loss = jnp.expand_dims(d_loss, -1)
return (jnp.multiply(d_loss, derivative), None, None, None, None, None)
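The two lines added above implement the first bullet of the commit message: when activations has rank greater than two, the incoming cotangent d_loss has one fewer axis than derivative, so a trailing axis is appended before the element-wise multiply. A tiny sketch of the shapes involved (the sizes are made up for illustration):

import jax.numpy as jnp

d_loss = jnp.ones((4, 16))          # cotangent of the per-example loss
derivative = jnp.ones((4, 16, 10))  # gradient w.r.t. the activations
if len(d_loss.shape) < len(derivative.shape):
  d_loss = jnp.expand_dims(d_loss, -1)   # now (4, 16, 1)
grad = jnp.multiply(d_loss, derivative)  # broadcasts to (4, 16, 10)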

