diff --git a/frontend/server/src/main/java/org/pytorch/serve/ensemble/DagExecutor.java b/frontend/server/src/main/java/org/pytorch/serve/ensemble/DagExecutor.java index 5dcb72c52e..cc7da3ba55 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/ensemble/DagExecutor.java +++ b/frontend/server/src/main/java/org/pytorch/serve/ensemble/DagExecutor.java @@ -1,5 +1,6 @@ package org.pytorch.serve.ensemble; +import com.google.common.util.concurrent.ThreadFactoryBuilder; import java.util.ArrayList; import java.util.HashSet; import java.util.List; @@ -14,6 +15,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import org.pytorch.serve.archive.model.ModelNotFoundException; @@ -41,8 +43,11 @@ public DagExecutor(Dag dag) { public ArrayList execute(RequestInput input, ArrayList topoSortedList) { CompletionService executorCompletionService = null; + ExecutorService executorService = null; if (topoSortedList == null) { - ExecutorService executorService = Executors.newFixedThreadPool(4); + ThreadFactory namedThreadFactory = + new ThreadFactoryBuilder().setNameFormat("wf-execute-thread-%d").build(); + executorService = Executors.newFixedThreadPool(4, namedThreadFactory); executorCompletionService = new ExecutorCompletionService<>(executorService); } @@ -140,6 +145,9 @@ public ArrayList execute(RequestInput input, ArrayList topoS } } } + if (executorService != null) { + executorService.shutdown(); + } return leafOutputs; } @@ -150,7 +158,7 @@ private NodeOutput invokeModel( InterruptedException { try { - logger.error(String.format("Invoking - %s for attempt %d", nodeName, retryAttempt)); + logger.info(String.format("Invoking - %s for attempt %d", nodeName, retryAttempt)); CompletableFuture respFuture = new CompletableFuture<>(); RestJob job = ApiUtils.addRESTInferenceJob(null, workflowModel.getName(), null, input); job.setResponsePromise(respFuture); diff --git a/frontend/server/src/main/java/org/pytorch/serve/workflow/WorkflowManager.java b/frontend/server/src/main/java/org/pytorch/serve/workflow/WorkflowManager.java index a391d54b33..116656537e 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/workflow/WorkflowManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/workflow/WorkflowManager.java @@ -1,5 +1,6 @@ package org.pytorch.serve.workflow; +import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.gson.JsonObject; import com.google.gson.JsonParser; import io.netty.channel.ChannelHandlerContext; @@ -23,6 +24,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.concurrent.ThreadFactory; import org.pytorch.serve.archive.DownloadArchiveException; import org.pytorch.serve.archive.model.ModelNotFoundException; import org.pytorch.serve.archive.model.ModelVersionNotFoundException; @@ -51,8 +53,11 @@ public final class WorkflowManager { private static final Logger logger = LoggerFactory.getLogger(WorkflowManager.class); + private final ThreadFactory namedThreadFactory = + new ThreadFactoryBuilder().setNameFormat("wf-manager-thread-%d").build(); private final ExecutorService inferenceExecutorService = - Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()); + Executors.newFixedThreadPool( + Runtime.getRuntime().availableProcessors(), namedThreadFactory); private static WorkflowManager workflowManager; private final ConfigManager configManager;