Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
refactoring change and merge all docGen in a single place
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Oct 23, 2018
1 parent 07cbdef commit 80171a1
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 200 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ private[mxnet] object APIDocGenerator{
hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
// Generate Java API documentation
hashCollector += org.apache.mxnet.javaapi.APIDocGenerator.absClassGen(FILE_PATH + "javaapi/")
hashCollector += javaClassGen(FILE_PATH + "javaapi/")
val finalHash = hashCollector.mkString("\n")
}

Expand All @@ -54,8 +54,46 @@ private[mxnet] object APIDocGenerator{
org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
}

def absClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
// scalastyle:off
def fileGen(filePath : String, packageName : String, packageDef : String,
absFuncs : List[String]) : String = {
// Copied from @mdespriee in PR #12489
val apacheLicence =
"""/*
|* Licensed to the Apache Software Foundation (ASF) under one or more
|* contributor license agreements. See the NOTICE file distributed with
|* this work for additional information regarding copyright ownership.
|* The ASF licenses this file to You under the Apache License, Version 2.0
|* (the "License"); you may not use this file except in compliance with
|* the License. You may obtain a copy of the License at
|*
|* http://www.apache.org/licenses/LICENSE-2.0
|*
|* Unless required by applicable law or agreed to in writing, software
|* distributed under the License is distributed on an "AS IS" BASIS,
|* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|* See the License for the specific language governing permissions and
|* limitations under the License.
|*/
|""".stripMargin
val scalaStyle = "// scalastyle:off"
val imports = "import org.apache.mxnet.annotation.Experimental"
val absClassDef = s"abstract class $packageName"

val finalStr =
s"""$apacheLicence
|$scalaStyle
|$packageDef
|$imports
|$absClassDef {
|${absFuncs.mkString("\n")}
|}""".stripMargin
val pw = new PrintWriter(new File(filePath + s"$packageName.scala"))
pw.write(finalStr)
pw.close()
MD5Generator(finalStr)
}

def absClassGen(filePath : String, isSymbol : Boolean) : String = {
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
// Defines Operators that should not generated
val notGenerated = Set("Custom")
Expand All @@ -68,19 +106,27 @@ private[mxnet] object APIDocGenerator{
s"$scalaDoc\n$defBody"
})
val packageName = if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase"
val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n"
val scalaStyle = "// scalastyle:off"
val packageDef = "package org.apache.mxnet"
val imports = "import org.apache.mxnet.annotation.Experimental"
val absClassDef = s"abstract class $packageName"
val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
pw.write(finalStr)
pw.close()
MD5Generator(finalStr)
fileGen(filePath, packageName, packageDef, absFuncs)
}

def javaClassGen(filePath : String) : String = {
val notGenerated = Set("Custom")
val absClassFunctions = getSymbolNDArrayMethods(false, true)
// TODO: Add Filter to the same location in case of refactor
val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_"))
.filterNot(ele => notGenerated.contains(ele.name))
.map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction)
val defBody = generateJavaAPISignature(absClassFunction)
s"$scalaDoc\n$defBody"
})
val packageName = "NDArrayBase"
val packageDef = "package org.apache.mxnet.javaapi"
fileGen(filePath, packageName, packageDef, absFuncs)
}

def nonTypeSafeClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
def nonTypeSafeClassGen(filePath : String, isSymbol : Boolean) : String = {
// scalastyle:off
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
val absFuncs = absClassFunctions.map(absClassFunction => {
Expand All @@ -95,17 +141,23 @@ private[mxnet] object APIDocGenerator{
}
})
val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase"
val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n"
val scalaStyle = "// scalastyle:off"
val packageDef = "package org.apache.mxnet"
val imports = "import org.apache.mxnet.annotation.Experimental"
val absClassDef = s"abstract class $packageName"
val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
import java.io._
val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
pw.write(finalStr)
pw.close()
MD5Generator(finalStr)
fileGen(filePath, packageName, packageDef, absFuncs)
}

/**
* Some of the C++ type name is not valid in Scala
* such as var and type. This method is to convert
* them into other names to get it passed
* @param in the input String
* @return converted name string
*/
def safetyNameCheck(in : String) : String = {
in match {
case "var" => "vari"
case "type" => "typeOf"
case _ => in
}
}

// Generate ScalaDoc type
Expand All @@ -117,11 +169,7 @@ private[mxnet] object APIDocGenerator{
})
desc += " * </pre>"
val params = func.listOfArgs.map({ absClassArg =>
val currArgName = absClassArg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case _ => absClassArg.argName
}
val currArgName = safetyNameCheck(absClassArg.argName)
s" * @param $currArgName\t\t${absClassArg.argDesc}"
})
val returnType = s" * @return ${func.returnType}"
Expand All @@ -135,11 +183,7 @@ private[mxnet] object APIDocGenerator{
def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = {
var argDef = ListBuffer[String]()
func.listOfArgs.foreach(absClassArg => {
val currArgName = absClassArg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case _ => absClassArg.argName
}
val currArgName = safetyNameCheck(absClassArg.argName)
if (absClassArg.isOptional) {
argDef += s"$currArgName : Option[${absClassArg.argType}] = None"
}
Expand All @@ -159,19 +203,48 @@ private[mxnet] object APIDocGenerator{
s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : $returnType"
}

def generateJavaAPISignature(func : absClassFunction) : String = {
// scalastyle:off
var argDef = ListBuffer[String]()
var classDef = ListBuffer[String]()
func.listOfArgs.foreach(absClassArg => {
val currArgName = safetyNameCheck(absClassArg.argName)
if (absClassArg.isOptional) {
classDef += s"def set${absClassArg.argName}(${absClassArg.argName} : ${absClassArg.argType}) : ${func.name}BuilderBase"
}
else {
argDef += s"$currArgName : ${absClassArg.argType}"
}
})
classDef += s"def setout(out : NDArray) : ${func.name}BuilderBase"
classDef += s"def invoke() : org.apache.mxnet.javaapi.NDArrayFuncReturn"
val experimentalTag = "@Experimental"
var finalStr = s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : ${func.name}BuilderBase\n"
finalStr += s"abstract class ${func.name}BuilderBase {\n ${classDef.mkString("\n ")}\n}"
finalStr
}


// List and add all the atomic symbol functions to current module.
private def getSymbolNDArrayMethods(isSymbol : Boolean): List[absClassFunction] = {
private def getSymbolNDArrayMethods(isSymbol : Boolean, isJava : Boolean = false): List[absClassFunction] = {
val opNames = ListBuffer.empty[String]
val returnType = if (isSymbol) "Symbol" else "NDArray"
val returnHeader = if (isJava) "org.apache.mxnet.javaapi." else "org.apache.mxnet."
_LIB.mxListAllOpNames(opNames)
// TODO: Add '_linalg_', '_sparse_', '_image_' support
// TODO: Add Filter to the same location in case of refactor
opNames.map(opName => {
val opHandle = new RefLong
_LIB.nnGetOpHandle(opName, opHandle)
makeAtomicSymbolFunction(opHandle.value, opName, "org.apache.mxnet." + returnType)
}).toList.filterNot(_.name.startsWith("_"))
makeAtomicSymbolFunction(opHandle.value, opName, returnHeader + returnType)
}).filterNot(_.name.startsWith("_")).groupBy(_.name.toLowerCase).map(ele => {
// Pattern matching for not generating depreciated method
if (ele._2.length == 1) ele._2.head
else {
if (ele._2.head.name.head.isLower) ele._2.head
else ele._2.last
}
}).toList
}

// Create an atomic symbol function by handle and function name.
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ private[mxnet] object JavaNDArrayMacro {
val argList = argNames zip argTypes map { case (argName, argType) =>
val typeAndOption =
CToScalaUtils.argumentCleaner(argName, argType,
"org.apache.mxnet.javaapi.NDArray", "javaapi.Shape")
"org.apache.mxnet.javaapi.NDArray")
new NDArrayArg(argName, typeAndOption._1, typeAndOption._2)
}
new NDArrayFunction(aliasName, argList.toList)
Expand Down
Loading

0 comments on commit 80171a1

Please sign in to comment.