Skip to content

Commit

Permalink
bmm and xlogy
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed Mar 15, 2023
1 parent bd305b1 commit c9e86c1
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 0 deletions.
16 changes: 16 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -1527,6 +1527,14 @@ default boolean allClose(NDArray other, double rtol, double atol, boolean equalN
*/
NDArray pow(NDArray other);

/**
 * Computes {@code this * log(other)} element-wise.
 *
 * <p>Follows the {@code xlogy} convention: the result is 0 where {@code this} is 0, even if
 * {@code log(other)} would be non-finite (see the engine tests for this behavior).
 *
 * @param other the other {@code NDArray}
 * @return the result {@code NDArray}
 */
NDArray xlogy(NDArray other);

/**
* Adds a number to this {@code NDArray} element-wise in place.
*
Expand Down Expand Up @@ -4292,6 +4300,14 @@ default NDArray argSort(int axis) {
*/
NDArray matMul(NDArray other);

/**
 * Computes the batch matrix product of this {@code NDArray} and the other {@code NDArray}.
 *
 * <p>Both arrays are treated as stacks of matrices: the leading (batch) dimensions must match,
 * and the matrix product is computed for each pair of matrices in the batch.
 *
 * @param other the other {@code NDArray} to perform batch matrix product with
 * @return the result {@code NDArray}
 */
NDArray batchMatMul(NDArray other);

/**
* Clips (limit) the values in this {@code NDArray}.
*
Expand Down
12 changes: 12 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,18 @@ public NDArray matMul(NDArray other) {
return getAlternativeArray().matMul(other);
}

/** {@inheritDoc} */
@Override
public NDArray batchMatMul(NDArray other) {
    // Delegate to the alternative engine's array, consistent with matMul in this adapter.
    // If the alternative engine does not support the op either, it will throw for us.
    return getAlternativeArray().batchMatMul(other);
}

/** {@inheritDoc} */
@Override
public NDArray xlogy(NDArray other) {
    // Delegate to the alternative engine's array, consistent with matMul in this adapter.
    // If the alternative engine does not support the op either, it will throw for us.
    return getAlternativeArray().xlogy(other);
}

/** {@inheritDoc} */
@Override
public NDArray clip(Number min, Number max) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1452,6 +1452,18 @@ public NDArray matMul(NDArray other) {
return manager.invoke("_npi_matmul", new NDArray[] {this, other}, null);
}

/** {@inheritDoc} */
@Override
public NDArray batchMatMul(NDArray other) {
    // Deliberately unsupported: this engine has no batch-matmul binding wired up.
    throw new UnsupportedOperationException();
}

/** {@inheritDoc} */
@Override
public NDArray xlogy(NDArray other) {
    // Deliberately unsupported: this engine has no xlogy binding wired up.
    throw new UnsupportedOperationException();
}

/** {@inheritDoc} */
@Override
public NDArray clip(Number min, Number max) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,15 @@ public PtNDArray pow(NDArray other) {
return JniUtils.pow(this, manager.from(other));
}

/**
 * {@inheritDoc}
 *
 * @throws IllegalArgumentException if either operand is a scalar
 */
@Override
public NDArray xlogy(NDArray other) {
    // Scalars are rejected up front rather than being passed to the native layer.
    boolean eitherScalar = isScalar() || other.isScalar();
    if (eitherScalar) {
        throw new IllegalArgumentException("scalar is not allowed for xlogy()");
    }
    PtNDArray rhs = manager.from(other);
    return JniUtils.xlogy(this, rhs);
}

/** {@inheritDoc} */
@Override
public PtNDArray addi(Number n) {
Expand Down Expand Up @@ -1357,6 +1366,15 @@ public NDArray matMul(NDArray other) {
return JniUtils.matmul(this, manager.from(other));
}

/**
 * {@inheritDoc}
 *
 * @throws IllegalArgumentException if either operand is a scalar
 */
@Override
public NDArray batchMatMul(NDArray other) {
    // Scalars are rejected up front rather than being passed to the native layer.
    boolean eitherScalar = isScalar() || other.isScalar();
    if (eitherScalar) {
        throw new IllegalArgumentException("scalar is not allowed for batchMatMul()");
    }
    PtNDArray rhs = manager.from(other);
    return JniUtils.bmm(this, rhs);
}

/** {@inheritDoc} */
@Override
public PtNDArray clip(Number min, Number max) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,18 @@ public static PtNDArray matmul(PtNDArray ndArray1, PtNDArray ndArray2) {
PyTorchLibrary.LIB.torchMatmul(ndArray1.getHandle(), ndArray2.getHandle()));
}

/**
 * Computes the batch matrix product of two arrays via the native {@code torchBmm} binding.
 *
 * @param ndArray1 the left-hand array (its manager owns the result)
 * @param ndArray2 the right-hand array
 * @return the batch matrix product as a new {@code PtNDArray}
 */
public static PtNDArray bmm(PtNDArray ndArray1, PtNDArray ndArray2) {
    long resultHandle = PyTorchLibrary.LIB.torchBmm(ndArray1.getHandle(), ndArray2.getHandle());
    return new PtNDArray(ndArray1.getManager(), resultHandle);
}

/**
 * Computes {@code ndArray1 * log(ndArray2)} via the native {@code torchXLogY} binding.
 *
 * @param ndArray1 the multiplier array (its manager owns the result)
 * @param ndArray2 the array whose element-wise log is taken
 * @return the result as a new {@code PtNDArray}
 */
public static PtNDArray xlogy(PtNDArray ndArray1, PtNDArray ndArray2) {
    long resultHandle =
            PyTorchLibrary.LIB.torchXLogY(ndArray1.getHandle(), ndArray2.getHandle());
    return new PtNDArray(ndArray1.getManager(), resultHandle);
}

public static PtNDArray dot(PtNDArray ndArray1, PtNDArray ndArray2) {
if (ndArray1.getShape().dimension() == 1) {
return new PtNDArray(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ native long torchLinspace(

native long torchMatmul(long self, long other);

// JNI binding for batch matrix multiply; handles are native torch::Tensor pointers.
native long torchBmm(long self, long other);

// JNI binding for element-wise self * log(other); handles are native torch::Tensor pointers.
native long torchXLogY(long self, long other);

native long torchDot(long self, long other);

native long torchLogicalAnd(long self, long other);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

#include "ai_djl_pytorch_jni_PyTorchLibrary.h"
#include "djl_pytorch_jni_exception.h"
#include "djl_pytorch_utils.h"
Expand Down Expand Up @@ -162,6 +163,26 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchMatmul(
API_END_RETURN()
}

// JNI entry point for batch matrix multiply. The jlong arguments are native
// torch::Tensor pointers; the result is heap-allocated and owned by the Java side.
JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchBmm(
    JNIEnv* env, jobject jthis, jlong jself, jlong jother) {
  API_BEGIN()
  const auto* self_tensor = reinterpret_cast<torch::Tensor*>(jself);
  const auto* other_tensor = reinterpret_cast<torch::Tensor*>(jother);
  torch::Tensor result = self_tensor->bmm(*other_tensor);
  return reinterpret_cast<uintptr_t>(new torch::Tensor(result));
  API_END_RETURN()
}

// JNI entry point for element-wise self * log(other). The jlong arguments are native
// torch::Tensor pointers; the result is heap-allocated and owned by the Java side.
JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchXLogY(
    JNIEnv* env, jobject jthis, jlong jself, jlong jother) {
  API_BEGIN()
  const auto* self_tensor = reinterpret_cast<torch::Tensor*>(jself);
  const auto* other_tensor = reinterpret_cast<torch::Tensor*>(jother);
  torch::Tensor result = self_tensor->xlogy(*other_tensor);
  return reinterpret_cast<uintptr_t>(new torch::Tensor(result));
  API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDot(
JNIEnv* env, jobject jthis, jlong jself, jlong jother) {
API_BEGIN()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1545,6 +1545,16 @@ public NDArray matMul(NDArray other) {
}
}

/** {@inheritDoc} */
@Override
public NDArray batchMatMul(NDArray other) {
    // Deliberately unsupported: this engine has no batch-matmul binding wired up.
    throw new UnsupportedOperationException();
}

/** {@inheritDoc} */
@Override
public NDArray xlogy(NDArray other) {
    // Deliberately unsupported: this engine has no xlogy binding wired up.
    throw new UnsupportedOperationException();
}

/** {@inheritDoc} */
@Override
public NDArray clip(Number min, Number max) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Blocks;
import ai.djl.nn.Parameter;
Expand Down Expand Up @@ -486,6 +487,16 @@ public void testMatMul() {
}
}

@Test
public void testBatchMatMul() {
    try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
        NDArray lhs = manager.randomNormal(0f, 1f, new Shape(10, 5, 4), DataType.FLOAT32);
        NDArray rhs = manager.randomNormal(0f, 1f, new Shape(10, 4, 3), DataType.FLOAT32);
        NDArray output = lhs.batchMatMul(rhs);
        Assert.assertEquals(output.getShape(), new Shape(10, 5, 3));
        // Shape alone would not catch a wrong-valued implementation: cross-check the
        // first batch element against the plain matrix product.
        NDArray expected = lhs.get(0).matMul(rhs.get(0));
        Assertions.assertAlmostEquals(output.get(0), expected);
    }
}

@Test
public void testDivScalar() {
try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
Expand Down Expand Up @@ -695,6 +706,17 @@ public void testReversePowScalar() {
}
}

@Test
public void testXLogYNDArray() {
    try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
        NDArray array = manager.create(new float[] {6, 12, 2, 0});
        NDArray other = manager.create(new float[] {3, 1, 2, 3});
        NDArray result = array.xlogy(other);
        // expected[i] = array[i] * ln(other[i]); the last entry pins the
        // xlogy convention that xlogy(0, y) == 0.
        NDArray expected = manager.create(new float[] {6.5917f, 0f, 1.3863f, 0f});
        Assertions.assertAlmostEquals(result, expected);
    }
}

@Test
public void testBatchDot() {
try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
Expand Down

0 comments on commit c9e86c1

Please sign in to comment.