diff --git a/src/NeuralNet/losses/BCE.hpp b/src/NeuralNet/losses/BCE.hpp index 411a8b7..6788195 100644 --- a/src/NeuralNet/losses/BCE.hpp +++ b/src/NeuralNet/losses/BCE.hpp @@ -17,6 +17,10 @@ class BCE : public Loss { -(yTrim.array() * oTrim.array().log() + (1.0 - yTrim.array()) * (1.0 - oTrim.array()).log()); + if (loss.array().isNaN().any()) + throw std::runtime_error( + "NaN value encountered. Inputs might be too big"); + return loss.sum(); }