forked from microsoft/LightGBM
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added model downloader R wrapper (microsoft#356)
- Loading branch information
1 parent
d287118
commit 12bff62
Showing
2 changed files
with
169 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
# Copyright (C) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See LICENSE in project root for information. | ||
|
||
DEFAULT_URL = "https://mmlspark.azureedge.net/datasets/CNTKModels/" | ||
|
||
#' A class for downloading CNTK pretrained models in R. To download all models use the downloadModels | ||
#' function. To browse models from the microsoft server please use remoteModels. | ||
#' | ||
#' Creates the ModelDownloader. | ||
#' | ||
#' @param sc A spark context for interfacing between python and java | ||
#' @param localPath The folder to save models to | ||
#' @param serverURL The location of the model Server, beware this default can change! | ||
#' @param ... Optional arguments; currently unused. | ||
#' | ||
#' @family Model downloader | ||
#' | ||
#' @export | ||
smd_model_downloader <- function(sc, localPath, serverURL=DEFAULT_URL, ...) { | ||
session <- spark_session(sc) | ||
env <- new.env(parent = emptyenv()) | ||
env$model <- "com.microsoft.ml.spark.ModelDownloader" | ||
downloader <- invoke_new(sc, env$model, session, localPath, serverURL) | ||
} | ||
|
||
#' Downloads the model by given name | ||
#' | ||
#' @param smd_model_downloader The model downloader | ||
#' @param name The name of the model | ||
#' @param ... Optional arguments; currently unused. | ||
#' | ||
#' @family Model downloader | ||
#' | ||
#' @export | ||
smd_download_by_name <- function(model_downloader, name, ...) { | ||
model <- invoke(model_downloader, "downloadByName", name) | ||
} | ||
|
||
#' Downloads models stored locally on the filesystem | ||
#' | ||
#' @param smd_model_downloader The model downloader | ||
#' @param ... Optional arguments; currently unused. | ||
#' | ||
#' @family Model downloader | ||
#' | ||
#' @export | ||
smd_local_models <- function(model_downloader, ...) { | ||
model <- invoke(model_downloader, "localModels") | ||
} | ||
|
||
#' Downloads models stored remotely. | ||
#' | ||
#' @param smd_model_downloader The model downloader | ||
#' @param ... Optional arguments; currently unused. | ||
#' | ||
#' @family Model downloader | ||
#' | ||
#' @export | ||
smd_remote_models <- function(model_downloader, ...) { | ||
model <- invoke(model_downloader, "remoteModels") | ||
} | ||
|
||
#' Gets the name of the downloaded model | ||
#' | ||
#' @param model The downloaded model | ||
#' @param ... Optional arguments; currently unused. | ||
#' | ||
#' @family Model downloader | ||
#' | ||
#' @export | ||
smd_get_model_name <- function(model, ...) { | ||
name <- invoke(model, "name") | ||
} | ||
|
||
#' Gets the location of the model's bytes | ||
#' | ||
#' @param model The downloaded model | ||
#' @param ... Optional arguments; currently unused. | ||
#' | ||
#' @family Model downloader | ||
#' | ||
#' @export | ||
smd_get_model_uri <- function(model, ...) { | ||
uri <- invoke(invoke(model, "uri"), "toString") | ||
} | ||
|
||
#' Gets the domain that the model operates on | ||
#' | ||
#' @param model The downloaded model | ||
#' @param ... Optional arguments; currently unused. | ||
#' | ||
#' @family Model downloader | ||
#' | ||
#' @export | ||
smd_get_model_type <- function(model, ...) { | ||
name <- invoke(model, "modelType") | ||
} | ||
|
||
#' Gets the sha256 hash of the models bytes | ||
#' | ||
#' @param model The downloaded model | ||
#' @param ... Optional arguments; currently unused. | ||
#' | ||
#' @family Model downloader | ||
#' | ||
#' @export | ||
smd_get_model_hash <- function(model, ...) { | ||
name <- invoke(model, "hash") | ||
} | ||
|
||
#' Gets the size of the model in bytes | ||
#' | ||
#' @param model The downloaded model | ||
#' @param ... Optional arguments; currently unused. | ||
#' | ||
#' @family Model downloader | ||
#' | ||
#' @export | ||
smd_get_model_size <- function(model, ...) { | ||
name <- invoke(model, "size") | ||
} | ||
|
||
#' Gets the node which represents the input | ||
#' | ||
#' @param model The downloaded model | ||
#' @param ... Optional arguments; currently unused. | ||
#' | ||
#' @family Model downloader | ||
#' | ||
#' @export | ||
smd_get_model_input_node <- function(model, ...) { | ||
name <- invoke(model, "inputNode") | ||
} | ||
|
||
#' Gets the number of layers of the model | ||
#' | ||
#' @param model The downloaded model | ||
#' @param ... Optional arguments; currently unused. | ||
#' | ||
#' @family Model downloader | ||
#' | ||
#' @export | ||
smd_get_model_num_layers <- function(model, ...) { | ||
name <- invoke(model, "numLayers") | ||
} | ||
|
||
#' Gets the names of nodes that represent layers in the network | ||
#' | ||
#' @param model The downloaded model | ||
#' @param ... Optional arguments; currently unused. | ||
#' | ||
#' @family Model downloader | ||
#' | ||
#' @export | ||
smd_get_model_layer_names <- function(model, ...) { | ||
name <- invoke(model, "layerNames") | ||
} |