
Commit

In the examples codebase, no longer allowing incorrectly formatted targets to be passed to cross entropy loss registration, replacing these instead with None. An informative warning will be printed when this happens.

PiperOrigin-RevId: 717971219
james-martens authored and KfacJaxDev committed Jan 21, 2025
1 parent 9182734 commit 410bd73
Showing 1 changed file with 12 additions and 1 deletion.
examples/losses.py (13 changes: 12 additions & 1 deletion)
@@ -15,6 +15,7 @@
 import types
 
 from typing import Any, Sequence, Mapping
+import warnings
 
 import haiku as hk
 import jax
@@ -92,11 +93,21 @@ def softmax_cross_entropy(
     raise NotImplementedError("Non-constant loss weights are not currently "
                               "supported.")
 
+  if logits.ndim == labels.ndim + 1:
+    targets = labels.reshape([-1])
+  else:
+    targets = None
+    warnings.warn("Incorrectly formatted labels detected for softmax cross "
+                  "entropy loss registration. Perhaps you passed 1-hot "
+                  "vectors or asked for label smoothing? These will be "
+                  "ignored for registration purposes, making the use of "
+                  "empirical Fisher estimators impossible.")
+
   # Currently the registration functions only support 2D array inputs values
   # for `logits`, and so we need the reshapes below.
   registration_module.register_softmax_cross_entropy_loss(
       logits.reshape([-1, logits.shape[-1]]),
-      targets=labels.reshape([-1]),
+      targets=targets,
       mask=mask.reshape([-1]) if mask is not None else None,
       weight=weight,
       **extra_registration_kwargs)
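For illustration, here is a minimal, self-contained sketch of the new label handling. The helper name _labels_for_registration and the example shapes are hypothetical and not part of the commit; the sketch only mirrors the rank check the diff adds, where integer class labels (one dimension fewer than the logits) are flattened for registration and anything else, such as 1-hot or label-smoothed targets, is replaced with None alongside a warning.

import warnings

import jax.numpy as jnp


def _labels_for_registration(logits, labels):
  """Hypothetical helper mirroring the label check added in this commit."""
  if logits.ndim == labels.ndim + 1:
    # Integer class labels have one dimension fewer than the logits, so they
    # can be flattened and passed to the registration call.
    return labels.reshape([-1])
  # Anything else (e.g. 1-hot or label-smoothed targets) is dropped from
  # registration with an informative warning, as in the diff above.
  warnings.warn("Incorrectly formatted labels detected for softmax cross "
                "entropy loss registration; ignoring them.")
  return None


logits = jnp.zeros([8, 10])
int_labels = jnp.zeros([8], dtype=jnp.int32)   # valid: flattened to shape (8,)
one_hot_labels = jnp.zeros([8, 10])            # invalid: replaced with None

assert _labels_for_registration(logits, int_labels).shape == (8,)
assert _labels_for_registration(logits, one_hot_labels) is None

As the warning text in the diff notes, registering the loss with targets=None means curvature estimators that need the actual labels, such as the empirical Fisher, cannot be used for that loss.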
