Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NN activation fns #811

Merged
merged 2 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 168 additions & 7 deletions R/d.R
Original file line number Diff line number Diff line change
Expand Up @@ -624,24 +624,185 @@
}
)

.rxD$ReLU <- list(
function(x) {
paste0("dReLU(", x, ")")
}
)

.rxD$dReLU <- list(
function(x) {
paste0("0")
}
)

.rxD$GELU <- list(
function(x) {
paste0("dGELU(", x, ")")
}
)

.rxD$dGELU <- list(
function(x) {
paste0("d2GELU(", x, ")")
}
)

.rxD$d2GELU <- list(
function(x) {
paste0("d3GELU(", x, ")")
}
)

.rxD$d3GELU <- list(
function(x) {
paste0("d4GELU(", x, ")")
}
)

.rxD$ELU <- list(
function(x, alpha) {
paste0("dELU(", x, ", ", alpha, ")")
},
function(x, alpha) {
paste0("dELUa(", x, ", ", alpha, ")")
})

.rxD$dELU <- list(
function(x, alpha) {
paste0("d2ELU(", x, ", ", alpha, ")")
},
function(x, alpha) {
paste0("d2aELU(", x, ", ", alpha, ")")
})

.rxD$dELUa <- list(
function(x, alpha) {
paste0("d2ELUa(", x, ", ", alpha, ")")
},
function(x, alpha) {
paste0("0")
}
)
.rxD$d2ELUa <- list(
function(x, alpha) {
paste0("d2ELUa(", x, ", ", alpha, ")")
},
function(x, alpha) {
paste0("0")
}
)

.rxD$d2ELU <- list(
function(x, alpha) {
paste0("d2ELU(", x, ", ", alpha, ")")
},
function(x, alpha) {
paste0("d2aELU(", x, ", ", alpha, ")")
})

.rxD$d2aELU <- list(
function(x, alpha) {
paste0("d2aELU(", x, ", ", alpha, ")")
},
function(x, alpha) {
paste0("0")
})

.rxD$softplus <- list(
function(x) {
paste0("dsoftplus(", x, ")")
})

.rxD$dsoftplus <- list(
function(x) {
paste0("d2softplus(", x, ")")
})

.rxD$d2softplus <- list(
function(x) {
paste0("d3softplus(", x, ")")
})

.rxD$d3softplus <- list(
function(x) {
paste0("d4softplus(", x, ")")
})

.rxD$SELU <- list(
function(x) {
paste0("dSELU(", x, ")")
})


.rxD$lReLU <- list(
function(x) {
paste0("dlReLU(", x, ")")
}
)

.rxD$dlReLU <- list(
function(x) {
paste0("0")
}
)

.rxD$PReLU <- list(
function(x, alpha) {
paste0("dPReLU(", x, ",", alpha, ")")
},
function(x, alpha) {
paste0("dPReLUa(", x, ",", alpha, ")")
})

.rxD$dPReLU <- list(
function(x, alpha) {
paste0("0")
},
function(x, alpha) {
paste0("dPReLUa1(", x, ",", alpha, ")")
})

.rxD$dPReLUa <- list(
function(x, alpha) {
paste0("dPReLUa1(", x, ",", alpha, ")")
},
function(x, alpha) {
paste0("0")
})

.rxD$dPReLUa1 <- list(
function(x, alpha) {
paste0("0")
},
function(x, alpha) {
paste0("0")
}
)

.rxD$Swish <- list(
function(x) {
paste0("dSwish(", x, ")")
}
)

#' This gives the derivative table for rxode2
#'
#' This will help allow registration of functions in `rxode2`
#'
#' This will help allow registration of functions in `rxode2`
#'
#' @return Derivative table environment for rxode2
#' @details
#'
#'
#' This environment is a derivative table;
#'
#'
#' For example:
#'
#'
#' Derivative(f(a,b,c), a) = fa()
#' Derivative(f(a,b,c), b) = fb()
#' Derivative(f(a,b,c), c) = fc()
#'
#'
#' Then the derivative table for `f` would be:
#'
#'
#' assign("f", list(fa(a,b,c), fb(a,b,c), fc(a,b,c)), rxode2parseD())
#'
#' fa translates the arguments to the derivative with respect to a
Expand Down
31 changes: 30 additions & 1 deletion R/symengine.R
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,36 @@ regIfOrElse <- rex::rex(or(regIf, regElse))
"llikXCauchyDscale"=4,
"llikXNorm"=4,
"llikXNormDmean"=4,
"llikXNormDsd"=4
"llikXNormDsd"=4,
"ReLU"=1,
"dReLU"=1,
"GELU"=1,
"dGELU"=1,
"d2GELU"=1,
"d3GELU"=1,
"d4GELU"=1,
"ELU"=2,
"dELU"=2,
"d2ELU"=2,
"d2aELU"=2,
"dELUa"=2,
"d2ELUa"=2,
"softplus"=1,
"dsoftplus"=1,
"d2softplus"=1,
"d3softplus"=1,
"d4softplus"=1,
"SELU"=1,
"dSELU"=1,
"lReLU"=1,
"dlReLU"=1,
"PReLU"=2,
"dPReLU"=2,
"d2PReLU"=2,
"dPReLUa"=2,
"dPReLUa1"=2,
"Swish"=1,
"dSwish"=1
)

.rxOnly <- c(
Expand Down
135 changes: 135 additions & 0 deletions inst/include/rxode2_model_shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,141 @@ static inline double _safe_log_(double a, rx_solve *rx) {
return log(a);
}
}

static inline double ReLU(double x) {
return (x > 0.0) ? x : 0.0;
}

static inline double dReLU(double x) {
return (x > 0.0) ? 1.0 : 0.0;
}

//
static inline double GELU(double x) {
return 0.5 * x * (1.0 + erf(x * M_SQRT1_2));
}

static inline double dGELU(double x) {
return 0.5 * (1.0 + erf(x * M_SQRT1_2)) + x * M_1_SQRT_2PI * exp(-0.5 * x * x);
}

static inline double d2GELU(double x) {
return (2.0- x*x) * exp(-0.5* x * x)*M_1_SQRT_2PI;
}

static inline double d3GELU(double x) {
return x * exp(-0.5 * x * x) * M_1_SQRT_2PI * (x * x - 4.0);
}

static inline double d4GELU(double x) {
return exp(-0.5*x*x)*M_1_SQRT_2PI*(7.0*x*x - 4.0 - x*x*x*x);
}

static inline double ELU(double x, double alpha) {
return (x > 0.0) ? x : (exp(x) - 1.0) * alpha;
}

// derivative of ELU with respect to x
static inline double dELU(double x, double alpha) {
return (x > 0.0) ? 1.0 : exp(x)*alpha;
}

// derivative of dELU with respect to x
static inline double d2ELU(double x, double alpha) {
return (x > 0.0) ? 0.0 : exp(x)*alpha;
}

// derivative of dELU with respect to alpha
static inline double d2aELU(double x, double alpha) {
return (x > 0.0) ? 0.0 : exp(x);
}

// derivative of ELU with respect to alpha
static inline double dELUa(double x, double alpha) {
return (x > 0.0) ? 0.0 : (exp(x) - 1.0);
}
// derivative of dELAa with respect to x
static inline double d2ELUa(double x, double alpha) {
return (x > 0.0) ? 0.0 : exp(x);
}

static inline double softplus(double x) {
return log(1.0 + exp(x));
}

static inline double dsoftplus(double x) {
return 1.0 / (1.0 + exp(-x));
}

static inline double d2softplus(double x) {
double ex = exp(x);
return ex / ((1.0 + ex) * (1.0 + ex));
}

static inline double d3softplus(double x) {
double ex = exp(-x);
double ex1 = (1.0 + ex);
return 2.0*exp(-2.0*x)/(ex1*ex1*ex1) - 1.0*ex/(ex1*ex1);
}

static inline double d4softplus(double x) {
double ex = exp(-x);
double ex1 = (1.0 + ex);
return 6.0*exp(-3.0*x)/(ex1*ex1*ex1*ex1) -
6.0*exp(-2.0*x)/(ex1*ex1*ex1) +
1.0*ex/(ex1*ex1);
}

static inline double SELU(double x) {
#define alpha 1.6732632423543772848170429916717
#define scale 1.0507009873554804934193349852946
return (x > 0.0) ? scale * x : scale * alpha * (exp(x) - 1.0);
#undef alpha
#undef scale
}

static inline double dSELU(double x) {
#define alpha 1.6732632423543772848170429916717
#define scale 1.0507009873554804934193349852946
return (x > 0.0) ? scale : scale * alpha * exp(x);
#undef alpha
#undef scale
}

static inline double lReLU(double x) {
return (x > 0.0) ? x : 0.01 * x;
}

static inline double dlReLU(double x) {
return (x > 0.0) ? 1.0 : 0.01;
}

static inline double PReLU(double x, double alpha) {
return (x >= 0.0) ? x : alpha * x;
}

static inline double dPReLU(double x, double alpha) {
return (x >= 0.0) ? 1.0 : alpha;
}

static inline double dPReLUa(double x, double alpha) {
return (x >= 0.0) ? 0.0 : x;
}

static inline double dPReLUa1(double x, double alpha) {
return (x >= 0.0) ? 0.0 : 1.0;
}

static inline double Swish(double x) {
return x / (1.0 + exp(-x));
}

static inline double dSwish(double x) {
double ex = exp(x);
double den = 1.0 + ex;
return ex / (den * den) + x * ex / (den * den);
}

#define _safe_log(a) _safe_log_(a, _solveData)
static inline double _div0_(double denom, rx_solve *rx) {
if (rx->safeZero) {
Expand Down
Loading
Loading