From cbff7b6cc695ef134b63af5fd27248ed9aceb6f4 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 19 Apr 2016 23:31:53 +0800 Subject: [PATCH] [scala]operands change for SimpleOp Registry --- .../src/main/scala/ml/dmlc/mxnet/NDArray.scala | 4 ++-- .../src/main/scala/ml/dmlc/mxnet/Symbol.scala | 17 +++++++---------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala index fbdb5630e4ff..f4aeeee49b35 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala @@ -922,7 +922,7 @@ class NDArrayConversions(val value: Float) { } def -(other: NDArray): NDArray = { - NDArray.invokeGenericFunc("_rminus_scalar", Array[Any](other, value))(0) + NDArray.invokeGenericFunc("_rminus_scalar", Array[Any](value, other))(0) } def *(other: NDArray): NDArray = { @@ -930,7 +930,7 @@ class NDArrayConversions(val value: Float) { } def /(other: NDArray): NDArray = { - NDArray.invokeGenericFunc("_rdiv_scalar", Array[Any](other, value))(0) + NDArray.invokeGenericFunc("_rdiv_scalar", Array[Any](value, other))(0) } } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala index ba9dea6d7a12..4b99adcc0eb9 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -754,8 +754,7 @@ object Symbol { } def pow[@specialized(Int, Float, Double) V](number: V, sym: Symbol): Symbol = { - Symbol.createFromListedSymbols("_PowerScalar")(Array(sym), - Map("scalar" -> number.toString, "scalar_on_left" -> "True")) + Symbol.createFromListedSymbols("_RPowerScalar")(Array(sym), Map("scalar" -> number.toString)) } /** @@ -863,8 +862,7 @@ object Symbol { } def max[@specialized(Int, Float, Double) V](left: V, right: Symbol): Symbol = { - createFromListedSymbols("_MaximumScalar")(Array(right), - Map("scalar" -> left.toString, "scalar_on_left" -> "True")) + createFromListedSymbols("_MaximumScalar")(Array(right), Map("scalar" -> left.toString)) } def min(left: Symbol, right: Symbol): Symbol = { @@ -876,8 +874,7 @@ object Symbol { } def min[@specialized(Int, Float, Double) V](left: V, right: Symbol): Symbol = { - createFromListedSymbols("_MinimumScalar")(Array(right), - Map("scalar" -> left.toString, "scalar_on_left" -> "True")) + createFromListedSymbols("_MinimumScalar")(Array(right), Map("scalar" -> left.toString)) } /** @@ -1528,8 +1525,8 @@ class SymbolConversions[@specialized(Int, Float, Double) V](val value: V) { } def -(other: Symbol): Symbol = { - Symbol.createFromListedSymbols("_MinusScalar")(Array(other), - Map("scalar" -> value.toString, "scalar_on_left" -> "True")) + Symbol.createFromListedSymbols("_RMinusScalar")( + Array(other), Map("scalar" -> value.toString)) } def *(other: Symbol): Symbol = { @@ -1537,7 +1534,7 @@ class SymbolConversions[@specialized(Int, Float, Double) V](val value: V) { } def /(other: Symbol): Symbol = { - Symbol.createFromListedSymbols("_DivScalar")(Array(other), - Map("scalar" -> value.toString, "scalar_on_left" -> "True")) + Symbol.createFromListedSymbols("_RDivScalar")( + Array(other), Map("scalar" -> value.toString)) } }