Skip to content

Commit

Permalink
Merge branch 'master' into support_old_torch
Browse files Browse the repository at this point in the history
  • Loading branch information
nijkah authored Apr 9, 2022
2 parents ea7a93c + 2b98375 commit a5d40de
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.pytorch.serve.ensemble;

import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
Expand All @@ -14,6 +15,7 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.pytorch.serve.archive.model.ModelNotFoundException;
Expand Down Expand Up @@ -41,8 +43,11 @@ public DagExecutor(Dag dag) {
public ArrayList<NodeOutput> execute(RequestInput input, ArrayList<String> topoSortedList) {

CompletionService<NodeOutput> executorCompletionService = null;
ExecutorService executorService = null;
if (topoSortedList == null) {
ExecutorService executorService = Executors.newFixedThreadPool(4);
ThreadFactory namedThreadFactory =
new ThreadFactoryBuilder().setNameFormat("wf-execute-thread-%d").build();
executorService = Executors.newFixedThreadPool(4, namedThreadFactory);
executorCompletionService = new ExecutorCompletionService<>(executorService);
}

Expand Down Expand Up @@ -140,6 +145,9 @@ public ArrayList<NodeOutput> execute(RequestInput input, ArrayList<String> topoS
}
}
}
if (executorService != null) {
executorService.shutdown();
}

return leafOutputs;
}
Expand All @@ -150,7 +158,7 @@ private NodeOutput invokeModel(
InterruptedException {
try {

logger.error(String.format("Invoking - %s for attempt %d", nodeName, retryAttempt));
logger.info(String.format("Invoking - %s for attempt %d", nodeName, retryAttempt));
CompletableFuture<byte[]> respFuture = new CompletableFuture<>();
RestJob job = ApiUtils.addRESTInferenceJob(null, workflowModel.getName(), null, input);
job.setResponsePromise(respFuture);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.pytorch.serve.workflow;

import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import io.netty.channel.ChannelHandlerContext;
Expand All @@ -23,6 +24,7 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import org.pytorch.serve.archive.DownloadArchiveException;
import org.pytorch.serve.archive.model.ModelNotFoundException;
import org.pytorch.serve.archive.model.ModelVersionNotFoundException;
Expand Down Expand Up @@ -51,8 +53,11 @@
public final class WorkflowManager {
private static final Logger logger = LoggerFactory.getLogger(WorkflowManager.class);

private final ThreadFactory namedThreadFactory =
new ThreadFactoryBuilder().setNameFormat("wf-manager-thread-%d").build();
private final ExecutorService inferenceExecutorService =
Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
Executors.newFixedThreadPool(
Runtime.getRuntime().availableProcessors(), namedThreadFactory);

private static WorkflowManager workflowManager;
private final ConfigManager configManager;
Expand Down

0 comments on commit a5d40de

Please sign in to comment.