Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-984] Java NDArray Documentation Generation #12835

Merged
merged 6 commits into from
Oct 26, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.mxnet.javaapi.DType.DType
import collection.JavaConverters._

@AddJNDArrayAPIs(false)
object NDArray {
object NDArray extends NDArrayBase {
implicit def fromNDArray(nd: org.apache.mxnet.NDArray): NDArray = new NDArray(nd)

implicit def toNDArray(jnd: NDArray): org.apache.mxnet.NDArray = jnd.nd
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ private[mxnet] object APIDocGenerator{
hashCollector += absClassGen(FILE_PATH, false)
hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
// Generate Java API documentation
hashCollector += javaClassGen(FILE_PATH + "javaapi/")
val finalHash = hashCollector.mkString("\n")
}

Expand All @@ -52,8 +54,45 @@ private[mxnet] object APIDocGenerator{
org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
}

def absClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
// scalastyle:off
def fileGen(filePath : String, packageName : String, packageDef : String,
absFuncs : List[String]) : String = {
val apacheLicense =
"""/*
|* Licensed to the Apache Software Foundation (ASF) under one or more
|* contributor license agreements. See the NOTICE file distributed with
|* this work for additional information regarding copyright ownership.
|* The ASF licenses this file to You under the Apache License, Version 2.0
|* (the "License"); you may not use this file except in compliance with
|* the License. You may obtain a copy of the License at
|*
|* http://www.apache.org/licenses/LICENSE-2.0
|*
|* Unless required by applicable law or agreed to in writing, software
|* distributed under the License is distributed on an "AS IS" BASIS,
|* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|* See the License for the specific language governing permissions and
|* limitations under the License.
|*/
|""".stripMargin
val scalaStyle = "// scalastyle:off"
val imports = "import org.apache.mxnet.annotation.Experimental"
val absClassDef = s"abstract class $packageName"

val finalStr =
s"""$apacheLicense
|$scalaStyle
|$packageDef
|$imports
|$absClassDef {
|${absFuncs.mkString("\n")}
|}""".stripMargin
val pw = new PrintWriter(new File(filePath + s"$packageName.scala"))
pw.write(finalStr)
pw.close()
MD5Generator(finalStr)
}

def absClassGen(filePath : String, isSymbol : Boolean) : String = {
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
// Defines Operators that should not generated
val notGenerated = Set("Custom")
Expand All @@ -66,19 +105,27 @@ private[mxnet] object APIDocGenerator{
s"$scalaDoc\n$defBody"
})
val packageName = if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase"
val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n"
val scalaStyle = "// scalastyle:off"
val packageDef = "package org.apache.mxnet"
val imports = "import org.apache.mxnet.annotation.Experimental"
val absClassDef = s"abstract class $packageName"
val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
pw.write(finalStr)
pw.close()
MD5Generator(finalStr)
fileGen(filePath, packageName, packageDef, absFuncs)
}

def javaClassGen(filePath : String) : String = {
val notGenerated = Set("Custom")
val absClassFunctions = getSymbolNDArrayMethods(false, true)
// TODO: Add Filter to the same location in case of refactor
val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_"))
.filterNot(ele => notGenerated.contains(ele.name))
.map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction)
val defBody = generateJavaAPISignature(absClassFunction)
s"$scalaDoc\n$defBody"
})
val packageName = "NDArrayBase"
val packageDef = "package org.apache.mxnet.javaapi"
fileGen(filePath, packageName, packageDef, absFuncs)
}

def nonTypeSafeClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
def nonTypeSafeClassGen(filePath : String, isSymbol : Boolean) : String = {
// scalastyle:off
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
val absFuncs = absClassFunctions.map(absClassFunction => {
Expand All @@ -93,17 +140,23 @@ private[mxnet] object APIDocGenerator{
}
})
val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase"
val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n"
val scalaStyle = "// scalastyle:off"
val packageDef = "package org.apache.mxnet"
val imports = "import org.apache.mxnet.annotation.Experimental"
val absClassDef = s"abstract class $packageName"
val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
import java.io._
val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
pw.write(finalStr)
pw.close()
MD5Generator(finalStr)
fileGen(filePath, packageName, packageDef, absFuncs)
}

/**
* Some of the C++ type name is not valid in Scala
* such as var and type. This method is to convert
* them into other names to get it passed
* @param in the input String
* @return converted name string
*/
def safetyNameCheck(in : String) : String = {
in match {
case "var" => "vari"
case "type" => "typeOf"
case _ => in
}
}

// Generate ScalaDoc type
Expand All @@ -115,11 +168,7 @@ private[mxnet] object APIDocGenerator{
})
desc += " * </pre>"
val params = func.listOfArgs.map({ absClassArg =>
val currArgName = absClassArg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case _ => absClassArg.argName
}
val currArgName = safetyNameCheck(absClassArg.argName)
s" * @param $currArgName\t\t${absClassArg.argDesc}"
})
val returnType = s" * @return ${func.returnType}"
Expand All @@ -133,11 +182,7 @@ private[mxnet] object APIDocGenerator{
def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = {
var argDef = ListBuffer[String]()
func.listOfArgs.foreach(absClassArg => {
val currArgName = absClassArg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case _ => absClassArg.argName
}
val currArgName = safetyNameCheck(absClassArg.argName)
if (absClassArg.isOptional) {
argDef += s"$currArgName : Option[${absClassArg.argType}] = None"
}
Expand All @@ -157,23 +202,57 @@ private[mxnet] object APIDocGenerator{
s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : $returnType"
}

def generateJavaAPISignature(func : absClassFunction) : String = {
var argDef = ListBuffer[String]()
var classDef = ListBuffer[String]()
func.listOfArgs.foreach(absClassArg => {
val currArgName = safetyNameCheck(absClassArg.argName)
// scalastyle:off
if (absClassArg.isOptional) {
classDef += s"def set${absClassArg.argName}(${absClassArg.argName} : ${absClassArg.argType}) : ${func.name}BuilderBase"
}
else {
argDef += s"$currArgName : ${absClassArg.argType}"
}
// scalastyle:on
})
classDef += s"def setout(out : NDArray) : ${func.name}BuilderBase"
classDef += s"def invoke() : org.apache.mxnet.javaapi.NDArrayFuncReturn"
val experimentalTag = "@Experimental"
// scalastyle:off
var finalStr = s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : ${func.name}BuilderBase\n"
// scalastyle:on
finalStr += s"abstract class ${func.name}BuilderBase {\n ${classDef.mkString("\n ")}\n}"
finalStr
}


// List and add all the atomic symbol functions to current module.
private def getSymbolNDArrayMethods(isSymbol : Boolean): List[absClassFunction] = {
private def getSymbolNDArrayMethods(isSymbol : Boolean,
isJava : Boolean = false): List[absClassFunction] = {
val opNames = ListBuffer.empty[String]
val returnType = if (isSymbol) "Symbol" else "NDArray"
val returnHeader = if (isJava) "org.apache.mxnet.javaapi." else "org.apache.mxnet."
_LIB.mxListAllOpNames(opNames)
// TODO: Add '_linalg_', '_sparse_', '_image_' support
// TODO: Add Filter to the same location in case of refactor
opNames.map(opName => {
val opHandle = new RefLong
_LIB.nnGetOpHandle(opName, opHandle)
makeAtomicSymbolFunction(opHandle.value, opName, "org.apache.mxnet." + returnType)
}).toList.filterNot(_.name.startsWith("_"))
makeAtomicSymbolFunction(opHandle.value, opName, returnHeader + returnType)
}).filterNot(_.name.startsWith("_")).groupBy(_.name.toLowerCase).map(ele => {
// Pattern matching for not generating depreciated method
if (ele._2.length == 1) ele._2.head
else {
if (ele._2.head.name.head.isLower) ele._2.head
else ele._2.last
}
}).toList
}

// Create an atomic symbol function by handle and function name.
private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String, returnType : String)
private def makeAtomicSymbolFunction(handle: SymbolHandle,
aliasName: String, returnType : String)
: absClassFunction = {
val name = new RefString
val desc = new RefString
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,12 @@ private[mxnet] object JavaNDArrayMacro {
// scalastyle:off
// Combine and build the function string
impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", args.toSeq, map.toMap)"
val classDef = s"class ${ndarrayfunction.name}Builder(${argDef.mkString(",")})"
val classDef = s"class ${ndarrayfunction.name}Builder(${argDef.mkString(",")}) extends ${ndarrayfunction.name}BuilderBase"
val classBody = s"${OptionArgDef.mkString("\n")}\n${classImpl.mkString("\n")}\ndef invoke() : $returnType = {${impl.mkString("\n")}}"
val classFinal = s"$classDef {$classBody}"
val functionDef = s"def ${ndarrayfunction.name} (${argDef.mkString(",")})"
val functionBody = s"new ${ndarrayfunction.name}Builder(${argDef.map(_.split(":")(0)).mkString(",")})"
val functionFinal = s"$functionDef = $functionBody"
val functionFinal = s"$functionDef : ${ndarrayfunction.name}BuilderBase = $functionBody"
// scalastyle:on
functionDefs += c.parse(functionFinal).asInstanceOf[DefDef]
classDefs += c.parse(classFinal).asInstanceOf[ClassDef]
Expand Down Expand Up @@ -195,7 +195,7 @@ private[mxnet] object JavaNDArrayMacro {
val argList = argNames zip argTypes map { case (argName, argType) =>
val typeAndOption =
CToScalaUtils.argumentCleaner(argName, argType,
"org.apache.mxnet.javaapi.NDArray", "javaapi.Shape")
"org.apache.mxnet.javaapi.NDArray")
new NDArrayArg(argName, typeAndOption._1, typeAndOption._2)
}
new NDArrayFunction(aliasName, argList.toList)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ private[mxnet] object CToScalaUtils {

// Convert C++ Types to Scala Types
def typeConversion(in : String, argType : String = "", argName : String,
returnType : String, shapeType : String = "Shape") : String = {
returnType : String) : String = {
val header = returnType.split("\\.").dropRight(1)
in match {
case "Shape(tuple)" | "ShapeorNone" => s"org.apache.mxnet.$shapeType"
case "Shape(tuple)" | "ShapeorNone" => s"${header.mkString(".")}.Shape"
case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType
case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]"
=> s"Array[$returnType]"
Expand Down Expand Up @@ -53,7 +54,7 @@ private[mxnet] object CToScalaUtils {
* @return (Scala_Type, isOptional)
*/
def argumentCleaner(argName: String, argType : String,
returnType : String, shapeType : String = "Shape") : (String, Boolean) = {
returnType : String) : (String, Boolean) = {
val spaceRemoved = argType.replaceAll("\\s+", "")
var commaRemoved : Array[String] = new Array[String](0)
// Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'}
Expand All @@ -73,7 +74,7 @@ private[mxnet] object CToScalaUtils {
s"""expected "default=..." got ${commaRemoved(2)}""")
(typeConversion(commaRemoved(0), argType, argName, returnType), true)
} else if (commaRemoved.length == 2 || commaRemoved.length == 1) {
val tempType = typeConversion(commaRemoved(0), argType, argName, returnType, shapeType)
val tempType = typeConversion(commaRemoved(0), argType, argName, returnType)
val tempOptional = tempType.equals("org.apache.mxnet.Symbol")
(tempType, tempOptional)
} else {
Expand Down