diff --git a/src/codegen/src/main/scala/WrapperGenerator.scala b/src/codegen/src/main/scala/WrapperGenerator.scala index 865bebc5cfd9..5479efa7e389 100644 --- a/src/codegen/src/main/scala/WrapperGenerator.scala +++ b/src/codegen/src/main/scala/WrapperGenerator.scala @@ -235,6 +235,18 @@ class SparklyRWrapperGenerator extends WrapperGenerator { |import(sparklyr) | |export(sdf_transform) + |export(smd_model_downloader) + |export(smd_download_by_name) + |export(smd_local_models) + |export(smd_remote_models) + |export(smd_get_model_name) + |export(smd_get_model_uri) + |export(smd_get_model_type) + |export(smd_get_model_hash) + |export(smd_get_model_size) + |export(smd_get_model_input_node) + |export(smd_get_model_num_layers) + |export(smd_get_model_layer_names) |""".stripMargin) def formatWrapperName(name: String): String = diff --git a/src/downloader/src/main/R/model_downloader.R b/src/downloader/src/main/R/model_downloader.R new file mode 100644 index 000000000000..f53001059c5a --- /dev/null +++ b/src/downloader/src/main/R/model_downloader.R @@ -0,0 +1,157 @@ +# Copyright (C) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE in project root for information. + +DEFAULT_URL = "https://mmlspark.azureedge.net/datasets/CNTKModels/" + +#' A class for downloading CNTK pretrained models in R. To download all models use the downloadModels +#' function. To browse models from the microsoft server please use remoteModels. +#' +#' Creates the ModelDownloader. +#' +#' @param sc A spark context for interfacing between python and java +#' @param localPath The folder to save models to +#' @param serverURL The location of the model Server, beware this default can change! +#' @param ... Optional arguments; currently unused. +#' +#' @family Model downloader +#' +#' @export +smd_model_downloader <- function(sc, localPath, serverURL=DEFAULT_URL, ...) { + session <- spark_session(sc) + env <- new.env(parent = emptyenv()) + env$model <- "com.microsoft.ml.spark.ModelDownloader" + downloader <- invoke_new(sc, env$model, session, localPath, serverURL) +} + +#' Downloads the model by given name +#' +#' @param smd_model_downloader The model downloader +#' @param name The name of the model +#' @param ... Optional arguments; currently unused. +#' +#' @family Model downloader +#' +#' @export +smd_download_by_name <- function(model_downloader, name, ...) { + model <- invoke(model_downloader, "downloadByName", name) +} + +#' Downloads models stored locally on the filesystem +#' +#' @param smd_model_downloader The model downloader +#' @param ... Optional arguments; currently unused. +#' +#' @family Model downloader +#' +#' @export +smd_local_models <- function(model_downloader, ...) { + model <- invoke(model_downloader, "localModels") +} + +#' Downloads models stored remotely. +#' +#' @param smd_model_downloader The model downloader +#' @param ... Optional arguments; currently unused. +#' +#' @family Model downloader +#' +#' @export +smd_remote_models <- function(model_downloader, ...) { + model <- invoke(model_downloader, "remoteModels") +} + +#' Gets the name of the downloaded model +#' +#' @param model The downloaded model +#' @param ... Optional arguments; currently unused. +#' +#' @family Model downloader +#' +#' @export +smd_get_model_name <- function(model, ...) { + name <- invoke(model, "name") +} + +#' Gets the location of the model's bytes +#' +#' @param model The downloaded model +#' @param ... Optional arguments; currently unused. +#' +#' @family Model downloader +#' +#' @export +smd_get_model_uri <- function(model, ...) { + uri <- invoke(invoke(model, "uri"), "toString") +} + +#' Gets the domain that the model operates on +#' +#' @param model The downloaded model +#' @param ... Optional arguments; currently unused. +#' +#' @family Model downloader +#' +#' @export +smd_get_model_type <- function(model, ...) { + name <- invoke(model, "modelType") +} + +#' Gets the sha256 hash of the models bytes +#' +#' @param model The downloaded model +#' @param ... Optional arguments; currently unused. +#' +#' @family Model downloader +#' +#' @export +smd_get_model_hash <- function(model, ...) { + name <- invoke(model, "hash") +} + +#' Gets the size of the model in bytes +#' +#' @param model The downloaded model +#' @param ... Optional arguments; currently unused. +#' +#' @family Model downloader +#' +#' @export +smd_get_model_size <- function(model, ...) { + name <- invoke(model, "size") +} + +#' Gets the node which represents the input +#' +#' @param model The downloaded model +#' @param ... Optional arguments; currently unused. +#' +#' @family Model downloader +#' +#' @export +smd_get_model_input_node <- function(model, ...) { + name <- invoke(model, "inputNode") +} + +#' Gets the number of layers of the model +#' +#' @param model The downloaded model +#' @param ... Optional arguments; currently unused. +#' +#' @family Model downloader +#' +#' @export +smd_get_model_num_layers <- function(model, ...) { + name <- invoke(model, "numLayers") +} + +#' Gets the names of nodes that represent layers in the network +#' +#' @param model The downloaded model +#' @param ... Optional arguments; currently unused. +#' +#' @family Model downloader +#' +#' @export +smd_get_model_layer_names <- function(model, ...) { + name <- invoke(model, "layerNames") +}