diff --git a/jax/loss.py b/jax/loss.py
index 297015f..5f5c0d1 100644
--- a/jax/loss.py
+++ b/jax/loss.py
@@ -58,7 +58,7 @@ def _internal_exp_t(u: jnp.ndarray, t: float) -> jnp.ndarray:
       functools.partial(_internal_exp_t, t=t), u)
 
 
-@functools.partial(jax.jit, static_argnums=2)
+@jax.jit
 def compute_normalization_fixed_point(activations: jnp.ndarray,
                                       t: float,
                                       num_iters: int = 5):
@@ -93,7 +93,7 @@ def body_fun(carry):
   return -log_t(1.0 / logt_partition, t) + mu
 
 
-@functools.partial(jax.jit, static_argnums=2)
+@jax.jit
 def compute_normalization_binary_search(activations: jnp.ndarray,
                                         t: float,
                                         num_iters: int = 10):
@@ -141,7 +141,7 @@ def body_fun(carry):
   return logt_partition + mu
 
 
-@functools.partial(jax.jit, static_argnums=2)
+@jax.jit
 def compute_tempered_normalization(activations: jnp.ndarray,
                                    t: float,
                                    num_iters: int = 5):
@@ -154,7 +154,7 @@ def compute_tempered_normalization(activations: jnp.ndarray,
       activations)
 
 
-@functools.partial(jax.jit, static_argnums=2)
+@jax.jit
 def compute_normalization(activations: jnp.ndarray,
                           t: float,
                           num_iters: int = 5):
@@ -174,7 +174,7 @@ def compute_normalization(activations: jnp.ndarray,
       activations)
 
 
-@functools.partial(jax.jit, static_argnums=2)
+@jax.jit
 def tempered_sigmoid(activations, t, num_iters=5):
   """Tempered sigmoid function.
 
@@ -195,7 +195,7 @@ def tempered_sigmoid(activations, t, num_iters=5):
   return jnp.reshape(one_class_probabilities, input_shape)
 
 
-@functools.partial(jax.jit, static_argnums=2)
+@jax.jit
 def tempered_softmax(activations, t, num_iters=5):
   """Tempered softmax function.
 
@@ -280,7 +280,7 @@ def bi_tempered_logistic_loss(activations,
   return loss_values
 
 
-@functools.partial(jax.jit, static_argnums=5)
+@jax.jit
 def bi_tempered_logistic_loss_fwd(activations,
                                   labels,
                                   t1,
@@ -351,6 +351,8 @@ def bi_tempered_logistic_loss_bwd(res, d_loss):
   escorts = escorts / jnp.sum(escorts, -1, keepdims=True)
   derivative = delta_probs_times_forget_factor - jnp.multiply(
       escorts, delta_forget_sum)
+  if len(d_loss.shape) < len(derivative.shape):
+    d_loss = jnp.expand_dims(d_loss, -1)
   return (jnp.multiply(d_loss, derivative), None, None, None, None, None)
 
 
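The bulk of this patch swaps `@functools.partial(jax.jit, static_argnums=2)` (and `static_argnums=5` on `bi_tempered_logistic_loss_fwd`) for a plain `@jax.jit`, so `num_iters` becomes a traced argument instead of a compile-time constant. A minimal sketch of why that is safe, assuming (as in these normalization routines) the iteration count only feeds a `lax.fori_loop`/`lax.while_loop` bound; the `halve_n_times` function here is hypothetical, not from the repo:

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def halve_n_times(x, num_iters):
  # num_iters arrives as a traced value; fori_loop accepts dynamic bounds
  # (it lowers to while_loop), so changing num_iters does not trigger a
  # recompilation the way a static_argnums argument would.
  return lax.fori_loop(0, num_iters, lambda i, v: v * 0.5, x)


print(halve_n_times(jnp.float32(8.0), 3))  # 1.0
```

The trade-off is that XLA can no longer unroll the loop at trace time, but callers stop paying one compilation per distinct `num_iters` value.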
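The final hunk fixes `bi_tempered_logistic_loss_bwd` for the case where the incoming cotangent `d_loss` has lower rank than `derivative`, e.g. a per-example loss of shape `[batch]` against a per-class derivative of shape `[batch, num_classes]`. A standalone repro of the broadcasting logic (the shapes are illustrative, not from the repo):

```python
import jax.numpy as jnp

batch, num_classes = 4, 3
d_loss = jnp.ones([batch])                   # cotangent of a per-example loss
derivative = jnp.ones([batch, num_classes])  # d(loss)/d(activations)

# Same guard as the patch: add a trailing class axis so the elementwise
# multiply broadcasts [batch, 1] * [batch, num_classes].
if len(d_loss.shape) < len(derivative.shape):
  d_loss = jnp.expand_dims(d_loss, -1)

print(jnp.multiply(d_loss, derivative).shape)  # (4, 3)
```

Without the guard, `jnp.multiply` would align `[batch]` with the trailing `num_classes` axis: an error when the two differ, and a silently mis-scaled gradient when `batch == num_classes`.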