add systematic loss functions
chaoming0625 committed Apr 6, 2024
1 parent 8e0659b commit 7fb14b8
Showing 17 changed files with 560 additions and 588 deletions.
31 changes: 31 additions & 0 deletions braintools/metric/__init__.py
@@ -13,3 +13,34 @@
# limitations under the License.
# ==============================================================================

from ._classification import *
from ._classification import __all__ as _classification_all
from ._correlation import *
from ._correlation import __all__ as _correlation_all
from ._fenchel_young import *
from ._fenchel_young import __all__ as _fenchel_young_all
from ._firings import *
from ._firings import __all__ as _firings_all
from ._lfp import *
from ._lfp import __all__ as _lfp_all
from ._ranking import *
from ._ranking import __all__ as _ranking_all
from ._regression import *
from ._regression import __all__ as _regression_all
from ._smoothing import *
from ._smoothing import __all__ as _smoothing_all

__all__ = (
    _classification_all
    + _correlation_all
    + _fenchel_young_all
    + _firings_all
    + _lfp_all
    + _ranking_all
    + _regression_all
    + _smoothing_all
)
del (_classification_all, _correlation_all,
     _fenchel_young_all, _firings_all,
     _lfp_all, _ranking_all,
     _regression_all, _smoothing_all)
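
With these re-exports, every loss defined in the private submodules becomes importable directly from braintools.metric. A minimal usage sketch with made-up data (sigmoid_binary_cross_entropy is one of the names re-exported from _classification below):

import jax.numpy as jnp
import braintools.metric as metric

logits = jnp.array([0.5, -1.2, 3.0])   # raw model outputs
labels = jnp.array([1.0, 0.0, 1.0])    # binary targets
# available at the package level thanks to the star imports above
loss = metric.sigmoid_binary_cross_entropy(logits, labels)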
21 changes: 18 additions & 3 deletions braintools/metric/_classification.py
@@ -26,6 +26,23 @@
import jax
import jax.numpy as jnp

__all__ = [
    'sigmoid_binary_cross_entropy',
    'hinge_loss',
    'perceptron_loss',
    'softmax_cross_entropy',
    'softmax_cross_entropy_with_integer_labels',
    'multiclass_hinge_loss',
    'multiclass_perceptron_loss',
    'poly_loss_cross_entropy',
    'kl_divergence',
    'kl_divergence_with_log_targets',
    'convex_kl_divergence',
    'ctc_loss',
    'ctc_loss_with_forward_probs',
    'sigmoid_focal_loss',
]


def assert_is_float(array):
  assert bc.math.is_float(array), 'Array must be float.'
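
For orientation, a hedged sketch of calling one of the newly exported losses; the signature is assumed from the optax-style naming rather than shown in this diff:

import jax.numpy as jnp
from braintools.metric import softmax_cross_entropy_with_integer_labels

logits = jnp.array([[2.0, 0.5, -1.0],
                    [0.1, 1.5, 0.3]])  # (batch, num_classes)
labels = jnp.array([0, 1])             # integer class index per example
per_example = softmax_cross_entropy_with_integer_labels(logits, labels)
print(per_example.mean())              # reduce to a scalar training loss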
@@ -103,7 +120,7 @@ def perceptron_loss(
  Returns:
    loss value.
  """
-  assert jnp.shape(predictor_outputs) == jnp.shape(targets)
+  assert jnp.shape(predictor_outputs) == jnp.shape(targets), 'shape mismatch'
  return jnp.maximum(0, - predictor_outputs * targets)
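
Because the body is max(0, -outputs * targets), the loss is zero exactly when the prediction and the ±1 target agree in sign. A quick hedged check (import path inferred from the __init__.py re-exports above):

import jax.numpy as jnp
from braintools.metric import perceptron_loss

outputs = jnp.array([2.0, -0.5, 1.0])
targets = jnp.array([1.0, 1.0, -1.0])  # ±1 labels
# element-wise max(0, -o * t) -> [0.0, 0.5, 1.0]
print(perceptron_loss(outputs, targets))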


@@ -545,9 +562,7 @@ def sigmoid_focal_loss(
  ce_loss = sigmoid_binary_cross_entropy(logits, labels)
  p_t = p * labels + (1 - p) * (1 - labels)
  loss = ce_loss * ((1 - p_t) ** gamma)
-
  weighted = lambda loss_arg: (alpha * labels + (1 - alpha) * (1 - labels)) * loss_arg
  not_weighted = lambda loss_arg: loss_arg
-
  loss = jax.lax.cond(alpha >= 0, weighted, not_weighted, loss)
  return loss
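
The jax.lax.cond keeps the alpha branch traceable under jit: a negative alpha disables class weighting instead of branching in Python. A hedged usage sketch (parameter names are taken from the body above; the full signature is not visible in this diff):

import jax.numpy as jnp
from braintools.metric import sigmoid_focal_loss

logits = jnp.array([1.5, -2.0, 0.3])
labels = jnp.array([1.0, 0.0, 1.0])
# gamma > 0 down-weights well-classified examples; alpha in [0, 1] balances classes
loss = sigmoid_focal_loss(logits, labels, alpha=0.25, gamma=2.0)
print(loss)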