Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-984] Java NDArray Documentation Generation (#12835)
Browse files Browse the repository at this point in the history
* cherry pick javaDoc changes

* update NDArray changes

* refactoring change and merge all docGen in a single place

* clean the scalastyle

* take on Piyush nit

* drop the comments
  • Loading branch information
lanking520 authored and nswamy committed Oct 26, 2018
1 parent f759984 commit 5aaa729
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.mxnet.javaapi.DType.DType
import collection.JavaConverters._

@AddJNDArrayAPIs(false)
object NDArray {
object NDArray extends NDArrayBase {
implicit def fromNDArray(nd: org.apache.mxnet.NDArray): NDArray = new NDArray(nd)

implicit def toNDArray(jnd: NDArray): org.apache.mxnet.NDArray = jnd.nd
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ private[mxnet] object APIDocGenerator{
hashCollector += absClassGen(FILE_PATH, false)
hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
// Generate Java API documentation
hashCollector += javaClassGen(FILE_PATH + "javaapi/")
val finalHash = hashCollector.mkString("\n")
}

Expand All @@ -52,8 +54,45 @@ private[mxnet] object APIDocGenerator{
org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
}

def absClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
// scalastyle:off
def fileGen(filePath : String, packageName : String, packageDef : String,
absFuncs : List[String]) : String = {
val apacheLicense =
"""/*
|* Licensed to the Apache Software Foundation (ASF) under one or more
|* contributor license agreements. See the NOTICE file distributed with
|* this work for additional information regarding copyright ownership.
|* The ASF licenses this file to You under the Apache License, Version 2.0
|* (the "License"); you may not use this file except in compliance with
|* the License. You may obtain a copy of the License at
|*
|* http://www.apache.org/licenses/LICENSE-2.0
|*
|* Unless required by applicable law or agreed to in writing, software
|* distributed under the License is distributed on an "AS IS" BASIS,
|* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|* See the License for the specific language governing permissions and
|* limitations under the License.
|*/
|""".stripMargin
val scalaStyle = "// scalastyle:off"
val imports = "import org.apache.mxnet.annotation.Experimental"
val absClassDef = s"abstract class $packageName"

val finalStr =
s"""$apacheLicense
|$scalaStyle
|$packageDef
|$imports
|$absClassDef {
|${absFuncs.mkString("\n")}
|}""".stripMargin
val pw = new PrintWriter(new File(filePath + s"$packageName.scala"))
pw.write(finalStr)
pw.close()
MD5Generator(finalStr)
}

def absClassGen(filePath : String, isSymbol : Boolean) : String = {
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
// Defines Operators that should not generated
val notGenerated = Set("Custom")
Expand All @@ -66,19 +105,27 @@ private[mxnet] object APIDocGenerator{
s"$scalaDoc\n$defBody"
})
val packageName = if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase"
val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n"
val scalaStyle = "// scalastyle:off"
val packageDef = "package org.apache.mxnet"
val imports = "import org.apache.mxnet.annotation.Experimental"
val absClassDef = s"abstract class $packageName"
val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
pw.write(finalStr)
pw.close()
MD5Generator(finalStr)
fileGen(filePath, packageName, packageDef, absFuncs)
}

def javaClassGen(filePath : String) : String = {
val notGenerated = Set("Custom")
val absClassFunctions = getSymbolNDArrayMethods(false, true)
// TODO: Add Filter to the same location in case of refactor
val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_"))
.filterNot(ele => notGenerated.contains(ele.name))
.map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction)
val defBody = generateJavaAPISignature(absClassFunction)
s"$scalaDoc\n$defBody"
})
val packageName = "NDArrayBase"
val packageDef = "package org.apache.mxnet.javaapi"
fileGen(filePath, packageName, packageDef, absFuncs)
}

def nonTypeSafeClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
def nonTypeSafeClassGen(filePath : String, isSymbol : Boolean) : String = {
// scalastyle:off
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
val absFuncs = absClassFunctions.map(absClassFunction => {
Expand All @@ -93,17 +140,23 @@ private[mxnet] object APIDocGenerator{
}
})
val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase"
val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n"
val scalaStyle = "// scalastyle:off"
val packageDef = "package org.apache.mxnet"
val imports = "import org.apache.mxnet.annotation.Experimental"
val absClassDef = s"abstract class $packageName"
val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
import java.io._
val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
pw.write(finalStr)
pw.close()
MD5Generator(finalStr)
fileGen(filePath, packageName, packageDef, absFuncs)
}

/**
* Some of the C++ type name is not valid in Scala
* such as var and type. This method is to convert
* them into other names to get it passed
* @param in the input String
* @return converted name string
*/
def safetyNameCheck(in : String) : String = {
in match {
case "var" => "vari"
case "type" => "typeOf"
case _ => in
}
}

// Generate ScalaDoc type
Expand All @@ -115,11 +168,7 @@ private[mxnet] object APIDocGenerator{
})
desc += " * </pre>"
val params = func.listOfArgs.map({ absClassArg =>
val currArgName = absClassArg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case _ => absClassArg.argName
}
val currArgName = safetyNameCheck(absClassArg.argName)
s" * @param $currArgName\t\t${absClassArg.argDesc}"
})
val returnType = s" * @return ${func.returnType}"
Expand All @@ -133,11 +182,7 @@ private[mxnet] object APIDocGenerator{
def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = {
var argDef = ListBuffer[String]()
func.listOfArgs.foreach(absClassArg => {
val currArgName = absClassArg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case _ => absClassArg.argName
}
val currArgName = safetyNameCheck(absClassArg.argName)
if (absClassArg.isOptional) {
argDef += s"$currArgName : Option[${absClassArg.argType}] = None"
}
Expand All @@ -157,23 +202,57 @@ private[mxnet] object APIDocGenerator{
s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : $returnType"
}

def generateJavaAPISignature(func : absClassFunction) : String = {
var argDef = ListBuffer[String]()
var classDef = ListBuffer[String]()
func.listOfArgs.foreach(absClassArg => {
val currArgName = safetyNameCheck(absClassArg.argName)
// scalastyle:off
if (absClassArg.isOptional) {
classDef += s"def set${absClassArg.argName}(${absClassArg.argName} : ${absClassArg.argType}) : ${func.name}BuilderBase"
}
else {
argDef += s"$currArgName : ${absClassArg.argType}"
}
// scalastyle:on
})
classDef += s"def setout(out : NDArray) : ${func.name}BuilderBase"
classDef += s"def invoke() : org.apache.mxnet.javaapi.NDArrayFuncReturn"
val experimentalTag = "@Experimental"
// scalastyle:off
var finalStr = s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : ${func.name}BuilderBase\n"
// scalastyle:on
finalStr += s"abstract class ${func.name}BuilderBase {\n ${classDef.mkString("\n ")}\n}"
finalStr
}


// List and add all the atomic symbol functions to current module.
private def getSymbolNDArrayMethods(isSymbol : Boolean): List[absClassFunction] = {
private def getSymbolNDArrayMethods(isSymbol : Boolean,
isJava : Boolean = false): List[absClassFunction] = {
val opNames = ListBuffer.empty[String]
val returnType = if (isSymbol) "Symbol" else "NDArray"
val returnHeader = if (isJava) "org.apache.mxnet.javaapi." else "org.apache.mxnet."
_LIB.mxListAllOpNames(opNames)
// TODO: Add '_linalg_', '_sparse_', '_image_' support
// TODO: Add Filter to the same location in case of refactor
opNames.map(opName => {
val opHandle = new RefLong
_LIB.nnGetOpHandle(opName, opHandle)
makeAtomicSymbolFunction(opHandle.value, opName, "org.apache.mxnet." + returnType)
}).toList.filterNot(_.name.startsWith("_"))
makeAtomicSymbolFunction(opHandle.value, opName, returnHeader + returnType)
}).filterNot(_.name.startsWith("_")).groupBy(_.name.toLowerCase).map(ele => {
// Pattern matching for not generating depreciated method
if (ele._2.length == 1) ele._2.head
else {
if (ele._2.head.name.head.isLower) ele._2.head
else ele._2.last
}
}).toList
}

// Create an atomic symbol function by handle and function name.
private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String, returnType : String)
private def makeAtomicSymbolFunction(handle: SymbolHandle,
aliasName: String, returnType : String)
: absClassFunction = {
val name = new RefString
val desc = new RefString
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,12 @@ private[mxnet] object JavaNDArrayMacro {
// scalastyle:off
// Combine and build the function string
impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", args.toSeq, map.toMap)"
val classDef = s"class ${ndarrayfunction.name}Builder(${argDef.mkString(",")})"
val classDef = s"class ${ndarrayfunction.name}Builder(${argDef.mkString(",")}) extends ${ndarrayfunction.name}BuilderBase"
val classBody = s"${OptionArgDef.mkString("\n")}\n${classImpl.mkString("\n")}\ndef invoke() : $returnType = {${impl.mkString("\n")}}"
val classFinal = s"$classDef {$classBody}"
val functionDef = s"def ${ndarrayfunction.name} (${argDef.mkString(",")})"
val functionBody = s"new ${ndarrayfunction.name}Builder(${argDef.map(_.split(":")(0)).mkString(",")})"
val functionFinal = s"$functionDef = $functionBody"
val functionFinal = s"$functionDef : ${ndarrayfunction.name}BuilderBase = $functionBody"
// scalastyle:on
functionDefs += c.parse(functionFinal).asInstanceOf[DefDef]
classDefs += c.parse(classFinal).asInstanceOf[ClassDef]
Expand Down Expand Up @@ -195,7 +195,7 @@ private[mxnet] object JavaNDArrayMacro {
val argList = argNames zip argTypes map { case (argName, argType) =>
val typeAndOption =
CToScalaUtils.argumentCleaner(argName, argType,
"org.apache.mxnet.javaapi.NDArray", "javaapi.Shape")
"org.apache.mxnet.javaapi.NDArray")
new NDArrayArg(argName, typeAndOption._1, typeAndOption._2)
}
new NDArrayFunction(aliasName, argList.toList)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ private[mxnet] object CToScalaUtils {

// Convert C++ Types to Scala Types
def typeConversion(in : String, argType : String = "", argName : String,
returnType : String, shapeType : String = "Shape") : String = {
returnType : String) : String = {
val header = returnType.split("\\.").dropRight(1)
in match {
case "Shape(tuple)" | "ShapeorNone" => s"org.apache.mxnet.$shapeType"
case "Shape(tuple)" | "ShapeorNone" => s"${header.mkString(".")}.Shape"
case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType
case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]"
=> s"Array[$returnType]"
Expand Down Expand Up @@ -53,7 +54,7 @@ private[mxnet] object CToScalaUtils {
* @return (Scala_Type, isOptional)
*/
def argumentCleaner(argName: String, argType : String,
returnType : String, shapeType : String = "Shape") : (String, Boolean) = {
returnType : String) : (String, Boolean) = {
val spaceRemoved = argType.replaceAll("\\s+", "")
var commaRemoved : Array[String] = new Array[String](0)
// Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'}
Expand All @@ -73,7 +74,7 @@ private[mxnet] object CToScalaUtils {
s"""expected "default=..." got ${commaRemoved(2)}""")
(typeConversion(commaRemoved(0), argType, argName, returnType), true)
} else if (commaRemoved.length == 2 || commaRemoved.length == 1) {
val tempType = typeConversion(commaRemoved(0), argType, argName, returnType, shapeType)
val tempType = typeConversion(commaRemoved(0), argType, argName, returnType)
val tempOptional = tempType.equals("org.apache.mxnet.Symbol")
(tempType, tempOptional)
} else {
Expand Down

0 comments on commit 5aaa729

Please sign in to comment.