From 9025c4f06effe5275fc4ad178388a5211de13d56 Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Sat, 4 Mar 2023 21:27:08 -0800 Subject: [PATCH] [api] Make Batchifier serializable --- api/src/main/java/ai/djl/translate/Batchifier.java | 3 ++- .../main/java/ai/djl/translate/PaddingStackBatchifier.java | 4 +++- .../java/ai/djl/translate/SimplePaddingStackBatchifier.java | 2 ++ api/src/main/java/ai/djl/translate/StackBatchifier.java | 2 ++ 4 files changed, 9 insertions(+), 2 deletions(-) diff --git a/api/src/main/java/ai/djl/translate/Batchifier.java b/api/src/main/java/ai/djl/translate/Batchifier.java index e8a49154f1d..32fa23f94fb 100644 --- a/api/src/main/java/ai/djl/translate/Batchifier.java +++ b/api/src/main/java/ai/djl/translate/Batchifier.java @@ -14,6 +14,7 @@ import ai.djl.ndarray.NDList; +import java.io.Serializable; import java.util.Arrays; import java.util.stream.IntStream; @@ -26,7 +27,7 @@ * as axis 0. Another implementation could be a concatenated batchifier for sequence elements that * will concatenate the data elements along the time axis. */ -public interface Batchifier { +public interface Batchifier extends Serializable { Batchifier STACK = new StackBatchifier(); diff --git a/api/src/main/java/ai/djl/translate/PaddingStackBatchifier.java b/api/src/main/java/ai/djl/translate/PaddingStackBatchifier.java index ab9272bed39..3f3bb1b2d6e 100644 --- a/api/src/main/java/ai/djl/translate/PaddingStackBatchifier.java +++ b/api/src/main/java/ai/djl/translate/PaddingStackBatchifier.java @@ -27,9 +27,11 @@ */ public final class PaddingStackBatchifier implements Batchifier { + private static final long serialVersionUID = 1L; + private List arraysToPad; private List dimsToPad; - private List paddingSuppliers; + private transient List paddingSuppliers; private List paddingSizes; private boolean includeValidLengths; diff --git a/api/src/main/java/ai/djl/translate/SimplePaddingStackBatchifier.java b/api/src/main/java/ai/djl/translate/SimplePaddingStackBatchifier.java index cafac13fd64..cab77c97099 100644 --- a/api/src/main/java/ai/djl/translate/SimplePaddingStackBatchifier.java +++ b/api/src/main/java/ai/djl/translate/SimplePaddingStackBatchifier.java @@ -22,6 +22,8 @@ */ public final class SimplePaddingStackBatchifier implements Batchifier { + private static final long serialVersionUID = 1L; + private float padding; /** diff --git a/api/src/main/java/ai/djl/translate/StackBatchifier.java b/api/src/main/java/ai/djl/translate/StackBatchifier.java index d4eb4695022..6d6abcae390 100644 --- a/api/src/main/java/ai/djl/translate/StackBatchifier.java +++ b/api/src/main/java/ai/djl/translate/StackBatchifier.java @@ -28,6 +28,8 @@ */ public class StackBatchifier implements Batchifier { + private static final long serialVersionUID = 1L; + /** {@inheritDoc} */ @Override public NDList batchify(NDList[] inputs) {