Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge branch 'master' into cxt-dns_csr/csr.T_dns
Browse files Browse the repository at this point in the history
  • Loading branch information
XiaotaoChen committed Jun 1, 2018
2 parents 7fad590 + 9854583 commit 5b78686
Show file tree
Hide file tree
Showing 12 changed files with 266 additions and 132 deletions.
8 changes: 4 additions & 4 deletions docs/Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ def init_git() {

try {
stage('Build Docs') {
node('mxnetlinux-cpu') {
node('restricted-mxnetlinux-cpu') {
ws('workspace/docs') {
init_git()
timeout(time: max_time, unit: 'MINUTES') {
sh "ci/build.py -p ubuntu_cpu /work/runtime_functions.sh build_docs ${params.tags_to_build} ${params.tag_list} ${params.tag_default} ${params.domain}"
archiveArtifacts 'docs/build_version_doc/artifacts.tgz'
build 'website build - test publish'
build 'restricted-website-publish'
}
}
}
Expand All @@ -62,13 +62,13 @@ try {
// set build status to success at the end
currentBuild.result = "SUCCESS"
} catch (caughtError) {
node("mxnetlinux-cpu") {
node("restricted-mxnetlinux-cpu") {
sh "echo caught ${caughtError}"
err = caughtError
currentBuild.result = "FAILURE"
}
} finally {
node("mxnetlinux-cpu") {
node("restricted-mxnetlinux-cpu") {
// Only send email if master failed
if (currentBuild.result == "FAILURE" && env.BRANCH_NAME == "master") {
emailext body: 'Build for MXNet branch ${BRANCH_NAME} has broken. Please view the build at ${BUILD_URL}', replyTo: '${EMAIL}', subject: '[BUILD FAILED] Branch ${BRANCH_NAME} build ${BUILD_NUMBER}', to: '${EMAIL}'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@ package org.apache.mxnet
* typesafe NDArray API: NDArray.api._
* Main code will be generated during compile time through Macros
*/
object NDArrayAPI {
object NDArrayAPI extends NDArrayAPIBase {
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ package org.apache.mxnet
* typesafe Symbol API: Symbol.api._
* Main code will be generated during compile time through Macros
*/
object SymbolAPI {
object SymbolAPI extends SymbolAPIBase {
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ object ImageClassifierExample {
batch = ListBuffer[String]()
}
}
output += batch.toList
if (batch.length > 0) {
output += batch.toList
}
output.toList
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ object SSDClassifierExample {
batch = ListBuffer[String]()
}
}
output += batch.toList
if (batch.length > 0) {
output += batch.toList
}
output.toList
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ object ImageClassifier {

/**
* Loads a batch of images from a folder
* @param inputImageDirPath Path to a folder of images
* @param inputImagePaths Path to a folder of images
* @return List of buffered images
*/
def loadInputBatch(inputImagePaths: List[String]): Traversable[BufferedImage] = {
Expand Down
24 changes: 24 additions & 0 deletions scala-package/macros/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
</dependency>
</dependencies>


<build>
<plugins>
<plugin>
Expand All @@ -70,6 +71,29 @@
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>exec-maven-plugin</artifactId>
<version>1.6.0</version>
<executions>
<execution>
<id>apidoc-generation</id>
<phase>package</phase>
<goals>
<goal>java</goal>
</goals>
</execution>
</executions>
<configuration>
<additionalClasspathElements>
<additionalClasspathElement>${project.parent.basedir}/init/target/classes</additionalClasspathElement>
</additionalClasspathElements>
<arguments>
<argument>${project.parent.basedir}/core/src/main/scala/org/apache/mxnet/</argument>
</arguments>
<mainClass>org.apache.mxnet.APIDocGenerator</mainClass>
</configuration>
</plugin>
<plugin>
<groupId>org.scalatest</groupId>
<artifactId>scalatest-maven-plugin</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet

import org.apache.mxnet.init.Base._
import org.apache.mxnet.utils.CToScalaUtils

import scala.collection.mutable.ListBuffer

/**
* This object will generate the Scala documentation of the new Scala API
* Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
* The code will be executed during Macros stage and file live in Core stage
*/
private[mxnet] object APIDocGenerator{
case class absClassArg(argName : String, argType : String, argDesc : String, isOptional : Boolean)
case class absClassFunction(name : String, desc : String,
listOfArgs: List[absClassArg], returnType : String)


def main(args: Array[String]) : Unit = {
val FILE_PATH = args(0)
absClassGen(FILE_PATH, true)
absClassGen(FILE_PATH, false)
}

def absClassGen(FILE_PATH : String, isSymbol : Boolean) : Unit = {
// scalastyle:off
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
// TODO: Add Filter to the same location in case of refactor
val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_")).map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction)
val defBody = generateAPISignature(absClassFunction, isSymbol)
s"$scalaDoc\n$defBody"
})
val packageName = if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase"
val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n"
val scalaStyle = "// scalastyle:off"
val packageDef = "package org.apache.mxnet"
val absClassDef = s"abstract class $packageName"
val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
import java.io._
val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
pw.write(finalStr)
pw.close()
}

// Generate ScalaDoc type
def generateAPIDocFromBackend(func : absClassFunction) : String = {
val desc = func.desc.split("\n").map({ currStr =>
s" * $currStr"
})
val params = func.listOfArgs.map({ absClassArg =>
val currArgName = absClassArg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case _ => absClassArg.argName
}
s" * @param $currArgName\t\t${absClassArg.argDesc}"
})
val returnType = s" * @return ${func.returnType}"
s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n */"
}

def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = {
var argDef = ListBuffer[String]()
func.listOfArgs.foreach(absClassArg => {
val currArgName = absClassArg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case _ => absClassArg.argName
}
if (absClassArg.isOptional) {
argDef += s"$currArgName : Option[${absClassArg.argType}] = None"
}
else {
argDef += s"$currArgName : ${absClassArg.argType}"
}
})
var returnType = func.returnType
if (isSymbol) {
argDef += "name : String = null"
argDef += "attr : Map[String, String] = null"
} else {
returnType = "org.apache.mxnet.NDArrayFuncReturn"
}
s"def ${func.name} (${argDef.mkString(", ")}) : ${returnType}"
}


// List and add all the atomic symbol functions to current module.
private def getSymbolNDArrayMethods(isSymbol : Boolean): List[absClassFunction] = {
val opNames = ListBuffer.empty[String]
val returnType = if (isSymbol) "Symbol" else "NDArray"
_LIB.mxListAllOpNames(opNames)
// TODO: Add '_linalg_', '_sparse_', '_image_' support
opNames.map(opName => {
val opHandle = new RefLong
_LIB.nnGetOpHandle(opName, opHandle)
makeAtomicSymbolFunction(opHandle.value, opName, "org.apache.mxnet." + returnType)
}).toList
}

// Create an atomic symbol function by handle and function name.
private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String, returnType : String)
: absClassFunction = {
val name = new RefString
val desc = new RefString
val keyVarNumArgs = new RefString
val numArgs = new RefInt
val argNames = ListBuffer.empty[String]
val argTypes = ListBuffer.empty[String]
val argDescs = ListBuffer.empty[String]

_LIB.mxSymbolGetAtomicSymbolInfo(
handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)

val realName = if (aliasName == name.value) "" else s"(a.k.a., ${name.value})"

val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) =>
val typeAndOption = CToScalaUtils.argumentCleaner(argType, returnType)
new absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2)
}
new absClassFunction(aliasName, desc.value, argList.toList, returnType)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.mxnet

import org.apache.mxnet.init.Base._
import org.apache.mxnet.utils.OperatorBuildUtils
import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils}

import scala.annotation.StaticAnnotation
import scala.collection.mutable.ListBuffer
Expand Down Expand Up @@ -133,8 +133,8 @@ private[mxnet] object NDArrayMacro {
impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", null, map.toMap)"
// scalastyle:on
// Combine and build the function string
val returnType = "org.apache.mxnet.NDArray"
var finalStr = s"def ${ndarrayfunction.name}New"
val returnType = "org.apache.mxnet.NDArrayFuncReturn"
var finalStr = s"def ${ndarrayfunction.name}"
finalStr += s" (${argDef.mkString(",")}) : $returnType"
finalStr += s" = {${impl.mkString("\n")}}"
c.parse(finalStr).asInstanceOf[DefDef]
Expand Down Expand Up @@ -175,63 +175,6 @@ private[mxnet] object NDArrayMacro {
}


// Convert C++ Types to Scala Types
private def typeConversion(in : String, argType : String = "") : String = {
in match {
case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape"
case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "org.apache.mxnet.NDArray"
case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]"
=> "Array[org.apache.mxnet.NDArray]"
case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat"
case "int" | "intorNone" | "int(non-negative)" => "Int"
case "long" | "long(non-negative)" => "Long"
case "double" | "doubleorNone" => "Double"
case "string" => "String"
case "boolean" | "booleanorNone" => "Boolean"
case "tupleof<float>" | "tupleof<double>" | "ptr" | "" => "Any"
case default => throw new IllegalArgumentException(
s"Invalid type for args: $default, $argType")
}
}


/**
* By default, the argType come from the C++ API is a description more than a single word
* For Example:
* <C++ Type>, <Required/Optional>, <Default=>
* The three field shown above do not usually come at the same time
* This function used the above format to determine if the argument is
* optional, what is it Scala type and possibly pass in a default value
* @param argType Raw arguement Type description
* @return (Scala_Type, isOptional)
*/
private def argumentCleaner(argType : String) : (String, Boolean) = {
val spaceRemoved = argType.replaceAll("\\s+", "")
var commaRemoved : Array[String] = new Array[String](0)
// Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'}
if (spaceRemoved.charAt(0)== '{') {
val endIdx = spaceRemoved.indexOf('}')
commaRemoved = spaceRemoved.substring(endIdx + 1).split(",")
commaRemoved(0) = "string"
} else {
commaRemoved = spaceRemoved.split(",")
}
// Optional Field
if (commaRemoved.length >= 3) {
// arg: Type, optional, default = Null
require(commaRemoved(1).equals("optional"))
require(commaRemoved(2).startsWith("default="))
(typeConversion(commaRemoved(0), argType), true)
} else if (commaRemoved.length == 2 || commaRemoved.length == 1) {
val tempType = typeConversion(commaRemoved(0), argType)
val tempOptional = tempType.equals("org.apache.mxnet.NDArray")
(tempType, tempOptional)
} else {
throw new IllegalArgumentException(
s"Unrecognized arg field: $argType, ${commaRemoved.length}")
}

}


// List and add all the atomic symbol functions to current module.
Expand Down Expand Up @@ -273,7 +216,7 @@ private[mxnet] object NDArrayMacro {
}
// scalastyle:on println
val argList = argNames zip argTypes map { case (argName, argType) =>
val typeAndOption = argumentCleaner(argType)
val typeAndOption = CToScalaUtils.argumentCleaner(argType, "org.apache.mxnet.NDArray")
new NDArrayArg(argName, typeAndOption._1, typeAndOption._2)
}
new NDArrayFunction(aliasName, argList.toList)
Expand Down
Loading

0 comments on commit 5b78686

Please sign in to comment.