Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #1 from javelinjs/simpleop
Browse files Browse the repository at this point in the history
[scala] operands change for SimpleOp Registry
  • Loading branch information
tqchen committed Apr 19, 2016
2 parents 19b3e08 + cbff7b6 commit 967cc3c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 12 deletions.
4 changes: 2 additions & 2 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -922,15 +922,15 @@ class NDArrayConversions(val value: Float) {
}

def -(other: NDArray): NDArray = {
NDArray.invokeGenericFunc("_rminus_scalar", Array[Any](other, value))(0)
NDArray.invokeGenericFunc("_rminus_scalar", Array[Any](value, other))(0)
}

def *(other: NDArray): NDArray = {
other * value
}

def /(other: NDArray): NDArray = {
NDArray.invokeGenericFunc("_rdiv_scalar", Array[Any](other, value))(0)
NDArray.invokeGenericFunc("_rdiv_scalar", Array[Any](value, other))(0)
}
}

Expand Down
17 changes: 7 additions & 10 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -754,8 +754,7 @@ object Symbol {
}

def pow[@specialized(Int, Float, Double) V](number: V, sym: Symbol): Symbol = {
Symbol.createFromListedSymbols("_PowerScalar")(Array(sym),
Map("scalar" -> number.toString, "scalar_on_left" -> "True"))
Symbol.createFromListedSymbols("_RPowerScalar")(Array(sym), Map("scalar" -> number.toString))
}

/**
Expand Down Expand Up @@ -863,8 +862,7 @@ object Symbol {
}

def max[@specialized(Int, Float, Double) V](left: V, right: Symbol): Symbol = {
createFromListedSymbols("_MaximumScalar")(Array(right),
Map("scalar" -> left.toString, "scalar_on_left" -> "True"))
createFromListedSymbols("_MaximumScalar")(Array(right), Map("scalar" -> left.toString))
}

def min(left: Symbol, right: Symbol): Symbol = {
Expand All @@ -876,8 +874,7 @@ object Symbol {
}

def min[@specialized(Int, Float, Double) V](left: V, right: Symbol): Symbol = {
createFromListedSymbols("_MinimumScalar")(Array(right),
Map("scalar" -> left.toString, "scalar_on_left" -> "True"))
createFromListedSymbols("_MinimumScalar")(Array(right), Map("scalar" -> left.toString))
}

/**
Expand Down Expand Up @@ -1528,16 +1525,16 @@ class SymbolConversions[@specialized(Int, Float, Double) V](val value: V) {
}

def -(other: Symbol): Symbol = {
Symbol.createFromListedSymbols("_MinusScalar")(Array(other),
Map("scalar" -> value.toString, "scalar_on_left" -> "True"))
Symbol.createFromListedSymbols("_RMinusScalar")(
Array(other), Map("scalar" -> value.toString))
}

def *(other: Symbol): Symbol = {
other + value
}

def /(other: Symbol): Symbol = {
Symbol.createFromListedSymbols("_DivScalar")(Array(other),
Map("scalar" -> value.toString, "scalar_on_left" -> "True"))
Symbol.createFromListedSymbols("_RDivScalar")(
Array(other), Map("scalar" -> value.toString))
}
}

0 comments on commit 967cc3c

Please sign in to comment.