
Commit 249a5e3

Adding dropout option to predict function

Author: MaximilianPi
Committed: Jul 8, 2021
Parent: 726fd45

Showing 3 changed files with 52 additions and 32 deletions.
sjSDM/R/sjSDM.R (10 changes: 6 additions & 4 deletions)
@@ -225,14 +225,16 @@ print.sjSDM = function(x, ...) {
 #' @param newdata newdata for predictions
 #' @param SP spatial predictors (e.g. X and Y coordinates)
 #' @param type raw or link
+#' @param dropout use dropout for predictions or not, only supported for DNNs
 #' @param ... optional arguments for compatibility with the generic function, no function implemented
 #' @import checkmate
 #' @export
-predict.sjSDM = function(object, newdata = NULL, SP = NULL, type = c("link", "raw"), ...) {
+predict.sjSDM = function(object, newdata = NULL, SP = NULL, type = c("link", "raw"), dropout = FALSE, ...) {
   object = checkModel(object)
 
   assert( checkNull(newdata), checkMatrix(newdata), checkDataFrame(newdata) )
   assert( checkNull(SP), checkMatrix(SP), checkDataFrame(SP) )
+  qassert( dropout, "B1")
 
   if(inherits(object, "spatial")) assert_class(object, "spatial")
@@ -245,7 +247,7 @@ predict.sjSDM = function(object, newdata = NULL, SP = NULL, type = c("link", "raw"), ...) {
 
 
   if(is.null(newdata)) {
-    return(object$model$predict(newdata = object$data$X, SP = object$spatial$X, link=link))
+    return(object$model$predict(newdata = object$data$X, SP = object$spatial$X, link=link, dropout = dropout, ...))
   } else {
 
     if(is.data.frame(newdata)) {
@@ -261,7 +263,7 @@ predict.sjSDM = function(object, newdata = NULL, SP = NULL, type = c("link", "raw"), ...) {
       }
 
     }
-    pred = object$model$predict(newdata = newdata, SP = sp, link=link, ...)
+    pred = object$model$predict(newdata = newdata, SP = sp, link=link, dropout = dropout, ...)
     return(pred)
 
 
@@ -276,7 +278,7 @@ predict.sjSDM = function(object, newdata = NULL, SP = NULL, type = c("link", "raw"), ...) {
         newdata = stats::model.matrix(object$formula, data.frame(newdata))
       }
     }
-    pred = object$model$predict(newdata = newdata, link=link, ...)
+    pred = object$model$predict(newdata = newdata, link=link, dropout = dropout, ...)
     return(pred)
 
   }
sjSDM/inst/python/sjSDM_py/model_sjSDM.py (69 changes: 42 additions & 27 deletions)
@@ -19,7 +19,7 @@ def __init__(self, device: str = "cpu", dtype: str = "float32"):
         self.losses = []
         self.env = None
         self.spatial = None
-        device, dtype = self._device_and_dtype(device, dtype)
+        device, dtype = self._device_and_dtype(device, dtype) # type: ignore
         self.device = device
         self.dtype = dtype
         torch.set_default_dtype(self.dtype)
@@ -61,13 +61,13 @@ def _device_and_dtype(device: Union[int, str, torch.device], dtype: Union[str, torch.dtype]):
             device = torch.device(device)
 
         if type(dtype) is not str:
-            return device, dtype
+            return device, dtype # type: ignore
 
         if dtype == "float32":
-            return device, torch.float32
+            return device, torch.float32 # type: ignore
         if dtype == "float64":
-            return device, torch.float64
-        return device, torch.float32
+            return device, torch.float64 # type: ignore
+        return device, torch.float32 # type: ignore
 
 
     def add_env(self,
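For reference, a standalone sketch of the contract this helper implements: dtype strings resolve to torch dtypes, with torch.float32 as the fallback for anything unrecognized. The function name here is illustrative, not package API, and the int-to-CUDA-index branch is an assumption based on the Union[int, str, torch.device] annotation.

from typing import Tuple, Union
import torch

def resolve_device_and_dtype(device: Union[int, str, torch.device],
                             dtype: Union[str, torch.dtype]) -> Tuple[torch.device, torch.dtype]:
    # Illustrative re-implementation, not the package's internal method.
    if isinstance(device, int):                    # assumption: an int is a CUDA index
        device = torch.device("cuda", device)
    elif not isinstance(device, torch.device):
        device = torch.device(device)
    if not isinstance(dtype, str):
        return device, dtype                       # already a torch.dtype
    if dtype == "float64":
        return device, torch.float64
    return device, torch.float32                   # "float32" and unknown strings

print(resolve_device_and_dtype("cpu", "float64"))  # (device(type='cpu'), torch.float64)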
@@ -276,9 +276,9 @@ def build(self,
 
         if self.device.type == 'cuda' and torch.cuda.is_available():
             #torch.set_default_tensor_type('torch.cuda.FloatTensor')
-            self.env.cuda(self.device)
+            self.env.cuda(self.device) # type: ignore
             if self.spatial is not None:
-                self.spatial.cuda(self.device)
+                self.spatial.cuda(self.device) # type: ignore
         else:
             torch.set_default_tensor_type('torch.FloatTensor')
 
@@ -290,14 +290,14 @@ def build(self,
         self.mixed = mixed
         r_dim = self.output_shape
         if not diag:
-            low = -np.sqrt(6.0/(r_dim+df))
-            high = np.sqrt(6.0/(r_dim+df))
-            self.sigma = torch.tensor(np.random.uniform(low, high, [r_dim, df]), requires_grad = True, dtype = self.dtype, device = self.device).to(self.device)
+            low = -np.sqrt(6.0/(r_dim+df)) # type: ignore
+            high = np.sqrt(6.0/(r_dim+df)) # type: ignore
+            self.sigma = torch.tensor(np.random.uniform(low, high, [r_dim, df]), requires_grad = True, dtype = self.dtype, device = self.device).to(self.device) # type: ignore
             self._loss_function = self._build_loss_function()
             self._build_cov_constrain_function(l1 = l1, l2 = l2, reg_on_Cov = reg_on_Cov, reg_on_Diag = reg_on_Diag, inverse = inverse)
             self.params.append([self.sigma])
         else:
-            self.sigma = torch.zeros([r_dim, r_dim], dtype = self.dtype, device = self.device).to(self.device)
+            self.sigma = torch.zeros([r_dim, r_dim], dtype = self.dtype, device = self.device).to(self.device) # type: ignore
             self.df = r_dim
             self._loss_function = self._build_loss_function()
             self._build_cov_constrain_function(l1 = l1, l2 = l2, reg_on_Cov = reg_on_Cov, reg_on_Diag = reg_on_Diag, inverse = inverse)
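The initialization bound sqrt(6/(r_dim + df)) used above is the Glorot/Xavier uniform limit for an r_dim x df matrix. A minimal sketch of the same draw, with illustrative sizes:

import numpy as np
import torch

r_dim, df = 10, 5                                  # illustrative dimensions
limit = np.sqrt(6.0 / (r_dim + df))                # Glorot/Xavier uniform bound
sigma = torch.tensor(np.random.uniform(-limit, limit, [r_dim, df]),
                     requires_grad=True)           # trained jointly, as in build()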
@@ -332,7 +332,7 @@ def fit(self,
             parallel (int, optional): Use parallelization for DataLoader. Defaults to 0.
             early_stopping_training (int, optional): Use early stopping or not. Defaults to -1.
         """
-        stepSize = np.floor(X.shape[0] / batch_size).astype(int)
+        stepSize = np.floor(X.shape[0] / batch_size).astype(int) # type: ignore
         dataLoader = self._get_DataLoader(X, Y, SP, batch_size, True, parallel)
         any_losses = len(self.losses) > 0
         batch_loss = np.zeros(stepSize)
@@ -341,8 +341,8 @@
         df = self.df
         alpha = self.alpha
 
-        if self.device.type == 'cuda':
-            device = self.device.type+ ":" + str(self.device.index)
+        if self.device.type == 'cuda': # type: ignore
+            device = self.device.type+ ":" + str(self.device.index) # type: ignore
         else:
             device = 'cpu'
 
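The 'cuda' branch above rebuilds the device string from its parts; the same pattern recurs in logLik() and predict(). A small sketch of that round-trip (plain PyTorch, no sjSDM dependencies, and constructing the handle needs no GPU):

import torch

dev = torch.device("cuda", 0)
assert dev.type == "cuda" and dev.index == 0
device = dev.type + ":" + str(dev.index)   # -> "cuda:0", as in the code above
assert str(dev) == device                  # str() produces the same string directly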
@@ -377,7 +377,7 @@ def update_func(loss):
                     sp = sp.to(self.device, non_blocking=True)
                     self.optimizer.zero_grad()
                     with torch.cuda.amp.autocast(enabled=mixed):
-                        mu = self.env(x) + self.spatial(sp)
+                        mu = self.env(x) + self.spatial(sp) # type: ignore
                         #tmp(mu: torch.Tensor, Ys: torch.Tensor, sigma: torch.Tensor, batch_size: int, sampling: int, df: int, alpha: float
                         loss = self._loss_function(mu, y, self.sigma, batch_size, sampling, df, alpha, device, self.dtype)
                         loss = loss.mean()
@@ -403,15 +403,15 @@ def update_func(loss):
                         if counter_early_stopping_training == early_stopping_training:
                             _ = sys.stdout.write("\nEarly stopping...")
                             break
-            self.spatial.eval()
+            self.spatial.eval() # type: ignore
         else:
             for epoch in ep_bar:
                 for step, (x, y) in enumerate(dataLoader):
                     x = x.to(self.device, non_blocking=True)
                     y = y.to(self.device, non_blocking=True)
                     self.optimizer.zero_grad()
                     with torch.cuda.amp.autocast(enabled=mixed):
-                        mu = self.env(x)
+                        mu = self.env(x) # type: ignore
                         loss = self._loss_function(mu, y, self.sigma, batch_size, sampling, df, alpha, device, self.dtype)
                         loss = loss.mean()
                         if any_losses:
@@ -468,8 +468,8 @@ def logLik(self,
         torch.cuda.empty_cache()
         any_losses = len(self.losses) > 0
 
-        if self.device.type == 'cuda':
-            device = self.device.type+ ":" + str(self.device.index)
+        if self.device.type == 'cuda': # type: ignore
+            device = self.device.type+ ":" + str(self.device.index) # type: ignore
         else:
             device = 'cpu'
 
@@ -481,18 +481,18 @@
                 x = x.to(self.device, non_blocking=True)
                 y = y.to(self.device, non_blocking=True)
                 sp = sp.to(self.device, non_blocking=True)
-                mu = self.env(x) + self.spatial(sp)
+                mu = self.env(x) + self.spatial(sp) # type: ignore
                 loss = loss_function(mu, y, self.sigma, x.shape[0], sampling, self.df, self.alpha, device, self.dtype)
                 logLik.append(loss.data)
         else:
             for step, (x, y) in enumerate(dataLoader):
                 x = x.to(self.device, non_blocking=True)
                 y = y.to(self.device, non_blocking=True)
-                mu = self.env(x)
+                mu = self.env(x) # type: ignore
                 loss = loss_function(mu, y, self.sigma, x.shape[0], sampling, self.df, self.alpha, device, self.dtype)
                 logLik.append(loss.data)
 
-        loss_reg = torch.tensor(0.0, device=self.device, dtype=self.dtype).to(self.device)
+        loss_reg = torch.tensor(0.0, device=self.device, dtype=self.dtype).to(self.device) # type: ignore
         for k in range(len(self.losses)):
             loss_reg = loss_reg.add(self.losses[k]())
 
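For context, self.losses holds zero-argument callables that return penalty tensors (the registered regularizers), and logLik() sums their values into loss_reg. A toy sketch of that accumulation, with made-up penalty values:

import torch

losses = [lambda: torch.tensor(0.10), lambda: torch.tensor(0.25)]  # made-up penalties
loss_reg = torch.tensor(0.0)
for k in range(len(losses)):
    loss_reg = loss_reg.add(losses[k]())
print(float(loss_reg))   # ~0.35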
@@ -511,7 +511,8 @@ def predict(self,
                 batch_size: int = 25,
                 parallel: int = 0,
                 sampling: int = 100,
-                link: bool = True) -> np.ndarray:
+                link: bool = True,
+                dropout: bool = False) -> np.ndarray:
         """Predict with sjSDM
 
         Args:
@@ -522,6 +523,7 @@
             parallel (int, optional): Parallelization of DataLoader. Defaults to 0.
             sampling (int, optional): Number of MC-samples for each species. Defaults to 100.
             link (bool, optional): Linear or response scale. Defaults to True.
+            dropout (bool, optional): Use dropout during predictions or not. Defaults to False.
 
         Returns:
             np.ndarray: Predictions
@@ -531,26 +533,39 @@
         loss_function = self._build_loss_function(train = False, raw=not link)
 
         pred = []
-        if self.device.type == 'cuda':
-            device = self.device.type+ ":" + str(self.device.index)
+        if self.device.type == 'cuda': # type: ignore
+            device = self.device.type+ ":" + str(self.device.index) # type: ignore
         else:
             device = 'cpu'
 
         self.env.eval()
         if type(SP) is np.ndarray:
             self.spatial.eval()
 
+        # TODO: export to method
+        if dropout:
+            self.env.train()
+            if type(SP) is np.ndarray:
+                self.spatial.train()
+
         if type(SP) is np.ndarray:
             for step, (x, sp) in enumerate(dataLoader):
                 x = x.to(self.device, non_blocking=True)
                 sp = sp.to(self.device, non_blocking=True)
-                mu = self.env(x) + self.spatial(sp)
+                mu = self.env(x) + self.spatial(sp) # type: ignore
                 # loss = self._loss_function(mu, y, self.sigma, batch_size, sampling, df, alpha, device)
                 loss = loss_function(mu, self.sigma, x.shape[0], sampling, self.df, self.alpha, device)
                 pred.append(loss)
+            self.spatial.eval()
         else:
             for step, (x) in enumerate(dataLoader):
                 x = x[0].to(self.device, non_blocking=True)
-                mu = self.env(x)
+                mu = self.env(x) # type: ignore
                 # loss = self._loss_function(mu, y, self.sigma, batch_size, sampling, df, alpha, device)
                 loss = loss_function(mu, self.sigma, x.shape[0], sampling, self.df, self.alpha, device)
                 pred.append(loss)
 
+        self.env.eval()
         return torch.cat(pred, dim = 0).data.cpu().numpy()
 
     def se(self,
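The substance of the commit is the train/eval toggling above: PyTorch dropout layers are stochastic only in train mode, so dropout=True switches the networks into train mode for the forward passes, and the trailing eval() calls restore inference mode afterwards. Repeating such stochastic passes gives a Monte-Carlo (MC-dropout) estimate of predictive uncertainty. A self-contained sketch with a toy network (not the sjSDM architecture):

import torch

net = torch.nn.Sequential(torch.nn.Linear(3, 16),
                          torch.nn.ReLU(),
                          torch.nn.Dropout(p=0.5),
                          torch.nn.Linear(16, 2))
x = torch.randn(8, 3)

with torch.no_grad():
    net.eval()                                     # dropout disabled: deterministic
    det = torch.stack([net(x) for _ in range(50)])
    assert float(det.std(0).max()) == 0.0          # identical passes

    net.train()                                    # dropout active, as with dropout = TRUE
    mc = torch.stack([net(x) for _ in range(50)])  # 50 stochastic forward passes
    net.eval()                                     # restore inference mode, like predict()

mean, sd = mc.mean(0), mc.std(0)                   # MC-dropout prediction and uncertainty

On the R side a single stochastic draw corresponds to predict(model, dropout = TRUE); averaging several such calls should approximate the MC-dropout mean in the same way.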
sjSDM/man/predict.sjSDM.Rd (5 changes: 4 additions & 1 deletion)

Some generated files are not rendered by default.
