From 2421c09e7f890695a72944179e0f748ead9148a3 Mon Sep 17 00:00:00 2001 From: lxning Date: Fri, 3 Feb 2023 17:46:39 -0800 Subject: [PATCH 01/47] expand model archiver runtime type and add model config file --- model-archiver/model_archiver/arg_parser.py | 9 +++++++++ .../model_archiver/manifest_components/manifest.py | 2 ++ .../model_archiver/manifest_components/model.py | 10 +++++++++- model-archiver/model_archiver/model_packaging.py | 2 ++ model-archiver/model_archiver/model_packaging_utils.py | 1 + 5 files changed, 23 insertions(+), 1 deletion(-) diff --git a/model-archiver/model_archiver/arg_parser.py b/model-archiver/model_archiver/arg_parser.py index d520ba3be2..f3f4096cdf 100644 --- a/model-archiver/model_archiver/arg_parser.py +++ b/model-archiver/model_archiver/arg_parser.py @@ -143,4 +143,13 @@ def export_model_args_parser(): " packages.", ) + parser_export.add_argument( + "-c", + "--config", + required=False, + type=str, + default=None, + help="Path to a yaml file containing model configuration eg. batch_size.", + ) + return parser_export diff --git a/model-archiver/model_archiver/manifest_components/manifest.py b/model-archiver/model_archiver/manifest_components/manifest.py index 9828a69951..2a61fee44e 100644 --- a/model-archiver/model_archiver/manifest_components/manifest.py +++ b/model-archiver/model_archiver/manifest_components/manifest.py @@ -11,6 +11,8 @@ class RuntimeType(Enum): PYTHON = "python" PYTHON3 = "python3" + PTMODELPIPELINE = "ptmodelpipeline" + DEEPSPEED = "deepspeed" class Manifest(object): diff --git a/model-archiver/model_archiver/manifest_components/model.py b/model-archiver/model_archiver/manifest_components/model.py index b985788599..9686608f50 100644 --- a/model-archiver/model_archiver/manifest_components/model.py +++ b/model-archiver/model_archiver/manifest_components/model.py @@ -10,7 +10,7 @@ class Model(object): """ def __init__(self, model_name, serialized_file, handler, model_file=None, model_version=None, - extensions=None, requirements_file=None): + extensions=None, requirements_file=None, config=None): self.model_name = model_name self.serialized_file = None @@ -27,6 +27,11 @@ def __init__(self, model_name, serialized_file, handler, model_file=None, model_ else: self.handler = handler.split("/")[-1] self.requirements_file = requirements_file + if config: + if sys.platform.startswith('win32') and config.find("\\") != -1: + self.config = config.split("\\")[-1] + else: + self.config = config.split("/")[-1] self.model_dict = self.__to_dict__() @@ -52,6 +57,9 @@ def __to_dict__(self): if self.requirements_file: model_dict['requirementsFile'] = self.requirements_file.split("/")[-1] + if self.config: + model_dict['config'] = self.config + return model_dict def __str__(self): diff --git a/model-archiver/model_archiver/model_packaging.py b/model-archiver/model_archiver/model_packaging.py index bcbbfda732..697e2b81b0 100644 --- a/model-archiver/model_archiver/model_packaging.py +++ b/model-archiver/model_archiver/model_packaging.py @@ -22,6 +22,7 @@ def package_model(args, manifest): extra_files = args.extra_files export_file_path = args.export_path requirements_file = args.requirements_file + config = args.config try: ModelExportUtils.validate_inputs(model_name, export_file_path) @@ -37,6 +38,7 @@ def package_model(args, manifest): "handler": handler, "extra_files": extra_files, "requirements-file": requirements_file, + "config": config } model_path = ModelExportUtils.copy_artifacts(model_name, **artifact_files) diff --git 
a/model-archiver/model_archiver/model_packaging_utils.py b/model-archiver/model_archiver/model_packaging_utils.py index 40eabbd249..454b501e1b 100644 --- a/model-archiver/model_archiver/model_packaging_utils.py +++ b/model-archiver/model_archiver/model_packaging_utils.py @@ -107,6 +107,7 @@ def generate_model(modelargs): handler=modelargs.handler, model_version=modelargs.version, requirements_file=modelargs.requirements_file, + config=modelargs.config ) return model From 9ebb6bf8bfb58238fe9fee3e799ede307dd53671 Mon Sep 17 00:00:00 2001 From: lxning Date: Wed, 22 Feb 2023 11:46:05 -0800 Subject: [PATCH 02/47] lmi frontend poc init --- frontend/archive/build.gradle | 4 + .../pytorch/serve/archive/model/Manifest.java | 9 + .../serve/archive/model/ModelArchive.java | 33 +++- .../serve/archive/model/s3/BinaryUtils.java | 2 +- .../serve/archive/model/s3/HttpUtils.java | 2 +- .../archive/workflow/WorkflowArchive.java | 2 +- frontend/gradle.properties | 1 + .../org/pytorch/serve/util/ConfigManager.java | 9 + .../java/org/pytorch/serve/wlm/Model.java | 42 ++++- .../pytorch/serve/wlm/WorkLoadManager.java | 32 +++- .../pytorch/serve/wlm/WorkerLifeCycle.java | 16 +- .../org/pytorch/serve/wlm/WorkerThread.java | 173 ++++++++++-------- model-archiver/model_archiver/arg_parser.py | 2 +- .../manifest_components/manifest.py | 3 - .../manifest_components/model.py | 15 +- .../model_archiver/model_packaging.py | 4 +- .../model_archiver/model_packaging_utils.py | 2 +- 17 files changed, 248 insertions(+), 103 deletions(-) diff --git a/frontend/archive/build.gradle b/frontend/archive/build.gradle index cce015aa26..234238590d 100644 --- a/frontend/archive/build.gradle +++ b/frontend/archive/build.gradle @@ -3,6 +3,10 @@ dependencies { api "org.slf4j:slf4j-api:${slf4j_api_version}" api "org.apache.logging.log4j:log4j-slf4j-impl:${slf4j_log4j_version}" api "com.google.code.gson:gson:${gson_version}" + implementation "org.yaml:snakeyaml:${snakeyaml_version}" + + compileOnly "org.projectlombok:lombok:${lombok_version}" + annotationProcessor "org.projectlombok:lombok:${lombok_version}" testImplementation "commons-cli:commons-cli:${commons_cli_version}" testImplementation "org.testng:testng:${testng_version}" diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/Manifest.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/Manifest.java index 18c44a2d9c..9764dd78d3 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/Manifest.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/Manifest.java @@ -64,6 +64,7 @@ public static final class Model { private String handler; private String envelope; private String requirementsFile; + private String configFile; public Model() {} @@ -122,6 +123,14 @@ public String getEnvelope() { public void setEnvelope(String envelope) { this.envelope = envelope; } + + public String getConfigFile() { + return configFile; + } + + public void setConfigFile(String configFile) { + this.configFile = configFile; + } } public enum RuntimeType { diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java index 134a931f81..399aba49ca 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java @@ -3,8 +3,13 @@ import java.io.File; import java.io.IOException; import java.io.InputStream; 
+import java.io.InputStreamReader; +import java.io.Reader; +import java.nio.charset.StandardCharsets; import java.nio.file.FileAlreadyExistsException; import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import java.util.List; import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; @@ -14,6 +19,8 @@ import org.pytorch.serve.archive.utils.ZipUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.yaml.snakeyaml.Yaml; +import org.yaml.snakeyaml.error.YAMLException; public class ModelArchive { @@ -25,12 +32,14 @@ public class ModelArchive { private String url; private File modelDir; private boolean extracted; + private ModelConfig modelConfig; public ModelArchive(Manifest manifest, String url, File modelDir, boolean extracted) { this.manifest = manifest; this.url = url; this.modelDir = modelDir; this.extracted = extracted; + this.modelConfig = null; } public static ModelArchive downloadModel( @@ -92,7 +101,7 @@ private static ModelArchive load(String url, File dir, boolean extracted) boolean failed = true; try { File manifestFile = new File(dir, "MAR-INF/" + MANIFEST_FILE); - Manifest manifest = null; + Manifest manifest; if (manifestFile.exists()) { manifest = ArchiveUtils.readFile(manifestFile, Manifest.class); } else { @@ -179,4 +188,26 @@ public void clean() { FileUtils.deleteQuietly(modelDir); } } + + public void setModelConfig(ModelConfig modelConfig) { + this.modelConfig = modelConfig; + } + + public ModelConfig getModelConfig() { + if (this.modelConfig != null && manifest.getModel().getConfigFile() != null) { + Path modelConfigFilePath = + Paths.get(modelDir.getAbsolutePath(), manifest.getModel().getConfigFile()); + + Yaml yaml = new Yaml(); + try (Reader r = + new InputStreamReader( + Files.newInputStream(modelConfigFilePath), StandardCharsets.UTF_8)) { + + setModelConfig(yaml.load(r)); + } catch (YAMLException | IOException e) { + logger.error("Failed to parse " + modelConfigFilePath.toString(), e); + } + } + return this.modelConfig; + } } diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/BinaryUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/BinaryUtils.java index a41a58dee0..a0c941b5fe 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/BinaryUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/BinaryUtils.java @@ -36,7 +36,7 @@ public static String toHex(byte[] data) { */ public static byte[] fromHex(String hexData) { byte[] result = new byte[(hexData.length() + 1) / 2]; - String hexNumber = null; + String hexNumber; int stringOffset = 0; int byteOffset = 0; while (stringOffset < hexData.length()) { diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java index 5a469d2513..8c03f78875 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/HttpUtils.java @@ -22,7 +22,7 @@ private HttpUtils() {} public static void copyURLToFile(URL endpointUrl, File modelLocation, boolean s3SseKmsEnabled) throws IOException { // for a simple GET, we have no body so supply the precomputed 'empty' hash - Map headers = null; + Map headers; if (s3SseKmsEnabled) { String awsAccessKey = System.getenv("AWS_ACCESS_KEY_ID"); String awsSecretKey = System.getenv("AWS_SECRET_ACCESS_KEY"); 
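
The hunk above makes `ModelArchive.getModelConfig()` lazily parse the YAML file named by the manifest's `configFile` entry using SnakeYAML. For orientation, a Python sketch of the equivalent lookup follows; the keys and defaults mirror `ModelConfig.java` introduced later in this series, while the helper name and the use of PyYAML are illustrative assumptions rather than TorchServe code.

```python
# Illustrative only: a Python analogue of ModelArchive.getModelConfig().
# Keys and defaults mirror ModelConfig.java from this patch series; the
# function name and PyYAML usage are assumptions, not TorchServe APIs.
import os
import yaml  # PyYAML


def load_model_config(model_dir, config_file):
    """Parse the manifest-referenced model config YAML, falling back to defaults."""
    cfg = {}
    if config_file:
        path = os.path.join(model_dir, config_file)
        try:
            with open(path, "r", encoding="utf-8") as f:
                cfg = yaml.safe_load(f) or {}
        except (OSError, yaml.YAMLError) as e:
            # ModelArchive logs the failure and keeps a null config.
            print(f"Failed to parse {path}: {e}")
    return {
        "minWorkers": cfg.get("minWorkers", 1),
        "maxWorkers": cfg.get("maxWorkers", 1),
        "batchSize": cfg.get("batchSize", 1),
        "maxBatchDelay": cfg.get("maxBatchDelay", 100),
        "responseTimeout": cfg.get("responseTimeout", 120),
        "parallelLevel": cfg.get("parallelLevel", 1),
        "gpuIds": cfg.get("gpuIds"),
    }
```

Note that the `this.modelConfig != null` guard in this first revision defeats the lazy load; a later patch in the series (the `ArchiveUtils.readYamlFile` refactor) flips it to `== null`.
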
diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java index fca6879645..929121839e 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java @@ -86,7 +86,7 @@ private static WorkflowArchive load(String url, File dir, boolean extracted) boolean failed = true; try { File manifestFile = new File(dir, "WAR-INF/" + MANIFEST_FILE); - Manifest manifest = null; + Manifest manifest; if (manifestFile.exists()) { manifest = readFile(manifestFile, Manifest.class); } else { diff --git a/frontend/gradle.properties b/frontend/gradle.properties index 4166db7039..42cd7c29ac 100644 --- a/frontend/gradle.properties +++ b/frontend/gradle.properties @@ -12,3 +12,4 @@ snakeyaml_version=1.31 grpc_version=1.50.0 protoc_version=3.18.0 lmax_disruptor_version=3.4.4 +lombok_version=1.18.26 \ No newline at end of file diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index 70a21df416..0c2b8d7428 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -98,6 +98,7 @@ public final class ConfigManager { private static final String TS_GRPC_MANAGEMENT_PORT = "grpc_management_port"; private static final String TS_ENABLE_GRPC_SSL = "enable_grpc_ssl"; private static final String TS_INITIAL_WORKER_PORT = "initial_worker_port"; + private static final String TS_INITIAL_DISTRIBUTION_PORT = "initial_distribution_port"; private static final String TS_WORKFLOW_STORE = "workflow_store"; // Configuration which are not documented or enabled through environment variables @@ -803,6 +804,14 @@ public void setInitialWorkerPort(int initialPort) { prop.setProperty(TS_INITIAL_WORKER_PORT, String.valueOf(initialPort)); } + public int getInitialDistributionPort() { + return Integer.parseInt(prop.getProperty(TS_INITIAL_DISTRIBUTION_PORT, "29500")); + } + + public void setInitialDistributionPort(int initialPort) { + prop.setProperty(TS_INITIAL_DISTRIBUTION_PORT, String.valueOf(initialPort)); + } + private void setModelConfig() { String modelConfigStr = prop.getProperty(MODEL_CONFIG, null); Type type = new TypeToken>>() {}.getType(); diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java index 370c3a40cf..b835e83df1 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java @@ -2,6 +2,8 @@ import com.google.gson.JsonObject; import java.io.File; +import java.util.ArrayList; +import java.util.Collections; import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; @@ -21,12 +23,12 @@ public class Model { public static final String DEFAULT_DATA_QUEUE = "DATA_QUEUE"; - public static final String MIN_WORKERS = "minWorkers"; public static final String MAX_WORKERS = "maxWorkers"; public static final String BATCH_SIZE = "batchSize"; public static final String MAX_BATCH_DELAY = "maxBatchDelay"; public static final String RESPONSE_TIMEOUT = "responseTimeout"; + public static final String PARALLEL_LEVEL = "parallelLevel"; public static final String DEFAULT_VERSION = 
"defaultVersion"; public static final String MAR_NAME = "marName"; @@ -37,6 +39,8 @@ public class Model { private int maxWorkers; private int batchSize; private int maxBatchDelay; + private int parallelLevel = 1; + private ArrayList gpuIds; private ReentrantLock lock; private int responseTimeout; private ModelVersionName modelVersionName; @@ -51,8 +55,18 @@ public class Model { public Model(ModelArchive modelArchive, int queueSize) { this.modelArchive = modelArchive; - batchSize = 1; - maxBatchDelay = 100; + if (modelArchive != null && modelArchive.getModelConfig() != null) { + minWorkers = modelArchive.getModelConfig().getMinWorkers(); + maxWorkers = modelArchive.getModelConfig().getMaxWorkers(); + batchSize = modelArchive.getModelConfig().getBatchSize(); + maxBatchDelay = modelArchive.getModelConfig().getMaxBatchDelay(); + responseTimeout = modelArchive.getModelConfig().getResponseTimeout(); + parallelLevel = modelArchive.getModelConfig().getParallelLevel(); + gpuIds = modelArchive.getModelConfig().getGpuIds(); + } else { + batchSize = 1; + maxBatchDelay = 100; + } jobsDb = new ConcurrentHashMap<>(); // Always have a queue for data jobsDb.putIfAbsent(DEFAULT_DATA_QUEUE, new LinkedBlockingDeque<>(queueSize)); @@ -73,6 +87,9 @@ public JsonObject getModelState(boolean isDefaultVersion) { modelInfo.addProperty(BATCH_SIZE, getBatchSize()); modelInfo.addProperty(MAX_BATCH_DELAY, getMaxBatchDelay()); modelInfo.addProperty(RESPONSE_TIMEOUT, getResponseTimeout()); + if (parallelLevel > 1) { + modelInfo.addProperty(PARALLEL_LEVEL, parallelLevel); + } return modelInfo; } @@ -83,6 +100,9 @@ public void setModelState(JsonObject modelInfo) { maxBatchDelay = modelInfo.get(MAX_BATCH_DELAY).getAsInt(); responseTimeout = modelInfo.get(RESPONSE_TIMEOUT).getAsInt(); batchSize = modelInfo.get(BATCH_SIZE).getAsInt(); + if (modelInfo.get(PARALLEL_LEVEL) != null) { + parallelLevel = modelInfo.get(PARALLEL_LEVEL).getAsInt(); + } } public String getModelName() { @@ -248,4 +268,20 @@ public int getResponseTimeout() { public void setResponseTimeout(int responseTimeout) { this.responseTimeout = responseTimeout; } + + public ArrayList getGpuIds() { + return this.gpuIds; + } + + public void setGpuIds(ArrayList gpuIds) { + Collections.copy(this.gpuIds, gpuIds); + } + + public void setParallelLevel(int parallelLevel) { + this.parallelLevel = parallelLevel; + } + + public int getParallelLevel() { + return this.parallelLevel; + } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java index c8f8b1d6a6..b630f28bd0 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java @@ -29,6 +29,7 @@ public class WorkLoadManager { private ConfigManager configManager; private EventLoopGroup backendGroup; private AtomicInteger port; + private AtomicInteger distributionPort; private AtomicInteger gpuCounter; private static final Logger logger = LoggerFactory.getLogger(WorkLoadManager.class); @@ -37,6 +38,7 @@ public WorkLoadManager(ConfigManager configManager, EventLoopGroup backendGroup) this.configManager = configManager; this.backendGroup = backendGroup; this.port = new AtomicInteger(configManager.getInitialWorkerPort()); + this.distributionPort = new AtomicInteger(configManager.getInitialDistributionPort()); this.gpuCounter = new AtomicInteger(0); threadPool = Executors.newCachedThreadPool(); workers = new 
ConcurrentHashMap<>(); @@ -146,7 +148,7 @@ public CompletableFuture modelChanged( // Need to check worker process here since thread.shutdown() -> lifecycle.exit() // -> This may nullify process object per destroyForcibly doc. - if (workerProcess != null && workerProcess.isAlive()) { + if ((workerProcess != null) && workerProcess.isAlive()) { boolean workerDestroyed = false; try { String cmd = String.format(OSUtils.getKillCmd(), workerProcess.pid()); @@ -193,19 +195,43 @@ private void addThreads( List threads, Model model, int count, CompletableFuture future) { WorkerStateListener listener = new WorkerStateListener(future, count); int maxGpu = configManager.getNumberOfGpu(); + if (maxGpu > 0 && model.getGpuIds() != null) { + maxGpu = model.getGpuIds().size(); + } + int parallelGpuIdx = 0; for (int i = 0; i < count; ++i) { int gpuId = -1; if (maxGpu > 0) { - gpuId = gpuCounter.accumulateAndGet(maxGpu, (prev, maxGpuId) -> ++prev % maxGpuId); + if (model.getParallelLevel() > 1) { + gpuId = + model.getGpuIds() != null + ? model.getGpuIds().get(parallelGpuIdx) + : parallelGpuIdx; + parallelGpuIdx += model.getParallelLevel(); + } else { + if (model.getGpuIds() != null) { + gpuId = model.getGpuIds().get(parallelGpuIdx++ % maxGpu); + } else { + gpuId = + gpuCounter.accumulateAndGet( + maxGpu, (prev, maxGpuId) -> ++prev % maxGpuId); + } + } } BatchAggregator aggregator = new BatchAggregator(model); + int currentPort = + model.getParallelLevel() > 1 + ? configManager.isDebug() + ? distributionPort.get() + : distributionPort.getAndAdd(model.getParallelLevel()) + : configManager.isDebug() ? port.get() : port.getAndIncrement(); WorkerThread thread = new WorkerThread( configManager, backendGroup, - configManager.isDebug() ? port.get() : port.getAndIncrement(), + currentPort, gpuId, model, aggregator, diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java index b4928a7143..e4a0a7fd15 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java @@ -93,7 +93,10 @@ public void startWorker(int port) throws WorkerInitializationException, Interrup throw new WorkerInitializationException("Failed get TS home directory", e); } - ArrayList argl = new ArrayList(); + ArrayList argl = new ArrayList<>(); + if (model.getParallelLevel() > 1) { + attachRunner(argl, port); + } argl.add(EnvironmentUtils.getPythonRunTime(model)); if (configManager.isCPULauncherEnabled()) { @@ -134,7 +137,7 @@ public void startWorker(int port) throws WorkerInitializationException, Interrup model.getModelArchive().getManifest().getModel().getHandler()); try { - latch = new CountDownLatch(1); + latch = new CountDownLatch(model.getParallelLevel()); String[] args = argl.toArray(new String[argl.size()]); logger.debug("Worker cmdline: {}", argl.toString()); @@ -166,6 +169,15 @@ public void startWorker(int port) throws WorkerInitializationException, Interrup } } + private void attachRunner(ArrayList argl, int port) { + argl.add("torchrun"); + argl.add("--nnodes=1"); + argl.add("--nproc_per_node=" + model.getParallelLevel()); + argl.add("--max_restarts=3"); + argl.add("--rdzv_backend=c10d"); + argl.add("--rdzv_endpoint=localhost:" + port); + } + public synchronized void terminateIOStreams() { if (errReader != null) { logger.warn("terminateIOStreams() threadName={}", errReader.getName()); diff --git 
a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java index cb126452bc..a9aefbf06b 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java @@ -16,6 +16,7 @@ import java.net.HttpURLConnection; import java.net.SocketAddress; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.UUID; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.CountDownLatch; @@ -61,7 +62,7 @@ public class WorkerThread implements Runnable { private int port; private Model model; - private Channel backendChannel; + private ArrayList backendChannel = new ArrayList<>(); private AtomicBoolean running = new AtomicBoolean(true); private int backoffIdx; @@ -185,10 +186,15 @@ public void run() { long wtStartTime = System.currentTimeMillis(); logger.info("Flushing req. to backend at: " + wtStartTime); - backendChannel.writeAndFlush(req).sync(); + for (int i = 0; backendChannel.size() > 0 && i < model.getParallelLevel(); i++) { + backendChannel.get(i).writeAndFlush(req).sync(); + } + ModelWorkerResponse reply = null; long begin = System.currentTimeMillis(); - ModelWorkerResponse reply = replies.poll(responseTimeout, TimeUnit.SECONDS); + for (int i = 0; i < model.getParallelLevel(); i++) { + reply = replies.poll(responseTimeout, TimeUnit.SECONDS); + } long duration = System.currentTimeMillis() - begin; logger.info("Backend response time: {}", duration); @@ -272,7 +278,9 @@ public void run() { // WorkerThread is running in thread pool, the thread will be assigned to next // Runnable once this worker is finished. If currentThread keep holding the reference // of the thread, currentThread.interrupt() might kill next worker. - backendChannel.disconnect(); + for (int i = 0; backendChannel.size() > 0 && i < model.getParallelLevel(); i++) { + backendChannel.get(i).disconnect(); + } currentThread.set(null); Integer exitValue = lifeCycle.getExitValue(); @@ -309,79 +317,88 @@ private void connect() throws WorkerInitializationException, InterruptedExceptio String modelName = model.getModelName(); String modelVersion = model.getVersion(); setState(WorkerState.WORKER_STARTED, HttpURLConnection.HTTP_OK); - final CountDownLatch latch = new CountDownLatch(1); - + final int parallelLevel = model.getParallelLevel(); + final CountDownLatch latch = new CountDownLatch(parallelLevel); final int responseBufferSize = configManager.getMaxResponseSize(); - try { - Connector connector = new Connector(port); - Bootstrap b = new Bootstrap(); - b.group(backendEventGroup) - .channel(connector.getClientChannel()) - .handler( - new ChannelInitializer() { - @Override - public void initChannel(Channel ch) { - ChannelPipeline p = ch.pipeline(); - p.addLast(ENCODER); - p.addLast(new ModelResponseDecoder(responseBufferSize)); - p.addLast(new WorkerHandler()); - } - }); - - SocketAddress address = connector.getSocketAddress(); - logger.info("Connecting to: {}", address); - backendChannel = b.connect(address).sync().channel(); - backendChannel - .closeFuture() - .addListener( - (ChannelFutureListener) - future -> { - latch.countDown(); - logger.info( - "{} Worker disconnected. 
{}", getWorkerId(), state); - Thread thread = currentThread.getAndSet(null); - if (thread != null) { - thread.interrupt(); - } - }); - - backendChannel - .newSucceededFuture() - .addListener( - (ChannelFutureListener) - future -> { - // TODO: - // use gpu, batch size in load model command - RequestInput input = - new RequestInput(UUID.randomUUID().toString()); - if (gpuId >= 0) { - input.addParameter( - new InputParameter( - "gpu", String.valueOf(gpuId))); - } - - Job job = - new RestJob( - null, - modelName, - modelVersion, - WorkerCommands.LOAD, - input); - model.addJob(workerId, job); - latch.countDown(); - }); - - if (!latch.await(WORKER_TIMEOUT, TimeUnit.MINUTES)) { - throw new WorkerInitializationException( - "Worker failed to initialize within " + WORKER_TIMEOUT + " mins"); - } - running.set(true); - } catch (Throwable t) { - // https://github.com/netty/netty/issues/2597 - if (t instanceof IOException) { - throw new WorkerInitializationException("Failed to connect to worker.", t); + for (int i = 0; i < parallelLevel; i++) { + try { + Connector connector = new Connector(port + i); + Bootstrap b = new Bootstrap(); + b.group(backendEventGroup) + .channel(connector.getClientChannel()) + .handler( + new ChannelInitializer() { + @Override + public void initChannel(Channel ch) { + ChannelPipeline p = ch.pipeline(); + p.addLast(ENCODER); + p.addLast(new ModelResponseDecoder(responseBufferSize)); + p.addLast(new WorkerHandler()); + } + }); + + SocketAddress address = connector.getSocketAddress(); + logger.info("Connecting to: {}", address); + backendChannel.add(b.connect(address).sync().channel()); + backendChannel + .get(i) + .closeFuture() + .addListener( + (ChannelFutureListener) + future -> { + latch.countDown(); + logger.info( + "{} Worker disconnected. 
{}", + getWorkerId(), + state); + Thread thread = currentThread.getAndSet(null); + if (thread != null) { + thread.interrupt(); + } + }); + + backendChannel + .get(i) + .newSucceededFuture() + .addListener( + (ChannelFutureListener) + future -> { + // TODO: + // use gpu, batch size in load model command + RequestInput input = + new RequestInput(UUID.randomUUID().toString()); + if (gpuId >= 0) { + input.addParameter( + new InputParameter( + "gpu", String.valueOf(gpuId))); + } + + if (latch.getCount() == parallelLevel) { + + Job job = + new RestJob( + null, + modelName, + modelVersion, + WorkerCommands.LOAD, + input); + model.addJob(workerId, job); + } + latch.countDown(); + }); + + if (!latch.await(WORKER_TIMEOUT, TimeUnit.MINUTES)) { + throw new WorkerInitializationException( + "Worker failed to initialize within " + WORKER_TIMEOUT + " mins"); + } + running.set(true); + } catch (Throwable t) { + // https://github.com/netty/netty/issues/2597 + if (t instanceof IOException) { + throw new WorkerInitializationException("Failed to connect to worker.", t); + } + throw t; } - throw t; } } @@ -404,8 +421,10 @@ public int getPid() { public void shutdown() { running.set(false); setState(WorkerState.WORKER_SCALED_DOWN, HttpURLConnection.HTTP_OK); - if (backendChannel != null) { - backendChannel.close(); + for (int i = 0; backendChannel.size() > 0 && i < model.getParallelLevel(); i++) { + if (backendChannel.get(i) != null) { + backendChannel.get(i).close(); + } } lifeCycle.terminateIOStreams(); Thread thread = currentThread.getAndSet(null); diff --git a/model-archiver/model_archiver/arg_parser.py b/model-archiver/model_archiver/arg_parser.py index f3f4096cdf..7d92964d94 100644 --- a/model-archiver/model_archiver/arg_parser.py +++ b/model-archiver/model_archiver/arg_parser.py @@ -145,7 +145,7 @@ def export_model_args_parser(): parser_export.add_argument( "-c", - "--config", + "--config-file", required=False, type=str, default=None, diff --git a/model-archiver/model_archiver/manifest_components/manifest.py b/model-archiver/model_archiver/manifest_components/manifest.py index 2a61fee44e..8ca037aa5e 100644 --- a/model-archiver/model_archiver/manifest_components/manifest.py +++ b/model-archiver/model_archiver/manifest_components/manifest.py @@ -11,9 +11,6 @@ class RuntimeType(Enum): PYTHON = "python" PYTHON3 = "python3" - PTMODELPIPELINE = "ptmodelpipeline" - DEEPSPEED = "deepspeed" - class Manifest(object): """ diff --git a/model-archiver/model_archiver/manifest_components/model.py b/model-archiver/model_archiver/manifest_components/model.py index 9686608f50..5abcf9bf0a 100644 --- a/model-archiver/model_archiver/manifest_components/model.py +++ b/model-archiver/model_archiver/manifest_components/model.py @@ -10,7 +10,7 @@ class Model(object): """ def __init__(self, model_name, serialized_file, handler, model_file=None, model_version=None, - extensions=None, requirements_file=None, config=None): + extensions=None, requirements_file=None, config_file=None): self.model_name = model_name self.serialized_file = None @@ -27,11 +27,12 @@ def __init__(self, model_name, serialized_file, handler, model_file=None, model_ else: self.handler = handler.split("/")[-1] self.requirements_file = requirements_file - if config: - if sys.platform.startswith('win32') and config.find("\\") != -1: - self.config = config.split("\\")[-1] + self.config_file = None + if config_file: + if sys.platform.startswith('win32') and config_file.find("\\") != -1: + self.config = config_file.split("\\")[-1] else: - self.config = 
config.split("/")[-1] + self.config = config_file.split("/")[-1] self.model_dict = self.__to_dict__() @@ -57,8 +58,8 @@ def __to_dict__(self): if self.requirements_file: model_dict['requirementsFile'] = self.requirements_file.split("/")[-1] - if self.config: - model_dict['config'] = self.config + if self.config_file: + model_dict['configFile'] = self.config_file return model_dict diff --git a/model-archiver/model_archiver/model_packaging.py b/model-archiver/model_archiver/model_packaging.py index 697e2b81b0..023528fbd1 100644 --- a/model-archiver/model_archiver/model_packaging.py +++ b/model-archiver/model_archiver/model_packaging.py @@ -22,7 +22,7 @@ def package_model(args, manifest): extra_files = args.extra_files export_file_path = args.export_path requirements_file = args.requirements_file - config = args.config + config_file = args.config_file try: ModelExportUtils.validate_inputs(model_name, export_file_path) @@ -38,7 +38,7 @@ def package_model(args, manifest): "handler": handler, "extra_files": extra_files, "requirements-file": requirements_file, - "config": config + "config_file": config_file } model_path = ModelExportUtils.copy_artifacts(model_name, **artifact_files) diff --git a/model-archiver/model_archiver/model_packaging_utils.py b/model-archiver/model_archiver/model_packaging_utils.py index 454b501e1b..57c9986858 100644 --- a/model-archiver/model_archiver/model_packaging_utils.py +++ b/model-archiver/model_archiver/model_packaging_utils.py @@ -107,7 +107,7 @@ def generate_model(modelargs): handler=modelargs.handler, model_version=modelargs.version, requirements_file=modelargs.requirements_file, - config=modelargs.config + config_file=modelargs.config_file ) return model From d50e4ae649f1d1091c735336fe08a93e0bfb01dc Mon Sep 17 00:00:00 2001 From: lxning Date: Wed, 22 Feb 2023 15:24:51 -0800 Subject: [PATCH 03/47] add modelConfig.java --- .../serve/archive/model/ModelConfig.java | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java new file mode 100644 index 0000000000..44058c2275 --- /dev/null +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java @@ -0,0 +1,72 @@ +package org.pytorch.serve.archive.model; + +import java.util.ArrayList; + +// import lombok.Data; + +// @Data +public class ModelConfig { + private int minWorkers = 1; + private int maxWorkers = 1; + private int batchSize = 1; + private int maxBatchDelay = 100; + private int responseTimeout = 120; + private ArrayList gpuIds; + private int parallelLevel = 1; + + public int getMinWorkers() { + return minWorkers; + } + + public void setMinWorkers(int minWorkers) { + this.minWorkers = minWorkers; + } + + public int getMaxWorkers() { + return maxWorkers; + } + + public void setMaxWorkers(int maxWorkers) { + this.maxWorkers = maxWorkers; + } + + public int getBatchSize() { + return batchSize; + } + + public void setBatchSize(int batchSize) { + this.batchSize = batchSize; + } + + public int getMaxBatchDelay() { + return maxBatchDelay; + } + + public void setMaxBatchDelay(int maxBatchDelay) { + this.maxBatchDelay = maxBatchDelay; + } + + public int getResponseTimeout() { + return responseTimeout; + } + + public void setResponseTimeout(int responseTimeout) { + this.responseTimeout = responseTimeout; + } + + public 
ArrayList getGpuIds() { + return gpuIds; + } + + public void setGpuIds(ArrayList gpuIds) { + this.gpuIds = gpuIds; + } + + public int getParallelLevel() { + return parallelLevel; + } + + public void setParallelLevel(int parallelLevel) { + this.parallelLevel = parallelLevel; + } +} From 5a69b6264be7d847339b83810253ad293c493448 Mon Sep 17 00:00:00 2001 From: lxning Date: Thu, 23 Feb 2023 17:58:38 -0800 Subject: [PATCH 04/47] try run_pippy --- examples/Huggingface_Largemodels/Readme.md | 14 +- .../Huggingface_Largemodels/pippy_handler.py | 268 ++++++++++++++++++ ts/model_service_worker.py | 3 +- 3 files changed, 280 insertions(+), 5 deletions(-) create mode 100644 examples/Huggingface_Largemodels/pippy_handler.py diff --git a/examples/Huggingface_Largemodels/Readme.md b/examples/Huggingface_Largemodels/Readme.md index 273c337a33..fd3ef0f3cd 100644 --- a/examples/Huggingface_Largemodels/Readme.md +++ b/examples/Huggingface_Largemodels/Readme.md @@ -1,7 +1,13 @@ # Loading large Huggingface models with constrained resources using accelerate -This document briefs on serving large HG models with limited resource using accelerate. This option can be activated with `low_cpu_mem_usage=True`. The model is first created on the Meta device (with empty weights) and the state dict is then loaded inside it (shard by shard in the case of a sharded checkpoint). +This document briefs on serving large HF model with PiPPy. + +### Step 0: Install torchserve from src +```bash +python ts_scripts/install_from_src.py + +``` ### Step 1: Download model Login into huggingface hub with token by running the below command @@ -12,7 +18,7 @@ huggingface-cli login paste the token generated from huggingface hub. ```bash -python Download_model.py --model_name bigscience/bloom-7b1 +python Download_model.py --model_name bigscience/bloom-1b1 ``` The script prints the path where the model is downloaded as below. @@ -28,7 +34,7 @@ Navigate to the path got from the above script. In this example it is ```bash cd model/models--bigscience-bloom-7b1/snapshots/5546055f03398095e385d7dc625e636cc8910bf2/ -zip -r /home/ubuntu/serve/examples/Huggingface_Largemodels//model.zip * +zip -r /home/ubuntu/serve/examples/Huggingface_Largemodels/model.zip * cd - ``` @@ -38,7 +44,7 @@ cd - Navigate up to `Huggingface_Largemodels` directory. 
```bash -torch-model-archiver --model-name bloom --version 1.0 --handler custom_handler.py --extra-files model.zip,setup_config.json -r requirements.txt +torch-model-archiver --model-name bloom --version 1.0 --handler pippy_handler.py --extra-files model.zip,setup_config.json -r requirements.txt ``` **__Note__**: Modifying setup_config.json diff --git a/examples/Huggingface_Largemodels/pippy_handler.py b/examples/Huggingface_Largemodels/pippy_handler.py new file mode 100644 index 0000000000..20f4060eb8 --- /dev/null +++ b/examples/Huggingface_Largemodels/pippy_handler.py @@ -0,0 +1,268 @@ +import json +import logging +import os +import zipfile +from abc import ABC + +import torch +import transformers +from transformers import BloomForCausalLM, BloomTokenizerFast + +from ts.torch_handler.base_handler import BaseHandler +import argparse +import inspect +import logging +import os +import time + +import torch +import pippy.fx +from pippy import run_pippy +from pippy.IR import MultiUseParameterConfig, Pipe +from pippy.PipelineDriver import PipelineDriverFillDrain, PipelineDriver1F1B, PipelineDriverInterleaved1F1B, \ + PipelineDriverBase +from pippy.hf import PiPPyHFTracer +from pippy.microbatch import TensorChunkSpec +from pippy import split_on_size_threshold, split_into_equal_size +from transformers import AutoModelForSeq2SeqLM +from transformers import OPTModel, BloomModel +from PIL import Image +import requests +from transformers import AutoFeatureExtractor, RegNetModel +from transformers import OPTForCausalLM + + +logger = logging.getLogger(__name__) +logger.info("Transformers version %s", transformers.__version__) + + +TORCH_DTYPES = { + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, +} + + +class TransformersSeqClassifierHandler(BaseHandler, ABC): + """ + Transformers handler class for sequence, token classification and question answering. + """ + + def __init__(self): + super(TransformersSeqClassifierHandler, self).__init__() + self.initialized = False + + def initialize(self, ctx): + """In this initialize function, the BERT model is loaded and + the Layer Integrated Gradients Algorithm for Captum Explanations + is initialized here. + Args: + ctx (context): It is a JSON Object containing information + pertaining to the model artefacts parameters. + """ + # parser = argparse.ArgumentParser() + # args = parser.parse_args() + # args.world_size = 4 + # args.gspmd = 1 + + self.manifest = ctx.manifest + properties = ctx.system_properties + model_dir = properties.get("model_dir") + + self.device = torch.device( + "cuda:" + str(properties.get("gpu_id")) + if torch.cuda.is_available() and properties.get("gpu_id") is not None + else "cpu" + ) + # Loading the model and tokenizer from checkpoint and config files based on the user's choice of mode + # further setup config can be added. + with zipfile.ZipFile(model_dir + "/model.zip", "r") as zip_ref: + zip_ref.extractall(model_dir + "/model") + + # read configs for the mode, model_name, etc. 
from setup_config.json + setup_config_path = os.path.join(model_dir, "setup_config.json") + if os.path.isfile(setup_config_path): + with open(setup_config_path) as setup_config_file: + self.setup_config = json.load(setup_config_file) + else: + logger.warning("Missing the setup_config.json file.") + + torch.manual_seed(42) + + MULTI_USE_PARAM_CONFIG = MultiUseParameterConfig.REPLICATE if args.replicate else MultiUseParameterConfig.TRANSMIT + print(f'REPLICATE config: {args.replicate} -> {MULTI_USE_PARAM_CONFIG}') + print("Using schedule:", args.schedule) + + device = args.device + self.model = BloomModel.from_pretrained( + model_dir + "/model", use_cache=False) + + + logger.info("********************* model loaded *************************", model_dir) + + # model = BloomModel.from_pretrained("bigscience/bloom-3b", use_cache=False) + + model_config = model.config + + model_config.use_cache = False # don't output `past_key_values` + model.eval() + print(model.config) + print(f"model total number of params = {get_number_of_params(model) // 10 ** 6}M") + + number_of_workers = len(pp_ranks) - pippy.utils.exclude_master + print(f"number_of_workers = {number_of_workers}") + + if args.auto_split == "threshold": + split_policy = split_on_size_threshold(490 * 1e6) + elif args.auto_split == "equal_size": + split_policy = split_into_equal_size(number_of_workers) + + all_worker_ranks = pp_ranks[pippy.utils.exclude_master:pippy.utils.exclude_master + number_of_workers] + chunks = args.chunks or len(all_worker_ranks) + bs = args.batch_size * chunks + seq_length = args.seq_length + + + input_names = ['input_ids'] + sig = inspect.signature(model.forward) + concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} + + print('Instantiating model Pipeline') + model_init_start = time.time() + model_pipe = Pipe.from_tracing(model, MULTI_USE_PARAM_CONFIG, tracer=PiPPyHFTracer(), concrete_args=concrete_args, + output_loss_value_spec=None, split_policy=split_policy + ) + + model_pipe.defer_stage_init(args.device) + + pippy.utils.pp_group_barrier() + + if args.rank!=0: + return + + split_gm_children = list(model_pipe.split_gm.children()) + + pipe_driver: PipelineDriverBase = schedules[args.schedule](model_pipe, chunks, + world_size=len(all_worker_ranks), + all_ranks=all_worker_ranks, + ) + + self.model = pipe_driver + logger.info("Transformer model from path %s loaded successfully", model_dir) + + self.initialized = True + + + def preprocess(self, requests): + """Basic text preprocessing, based on the user's chocie of application mode. + Args: + requests (str): The Input data in the form of text is passed on to the preprocess + function. + Returns: + list : The preprocess function returns a list of Tensor for the size of the word tokens. 
+ """ + input_ids_batch = None + attention_mask_batch = None + for idx, data in enumerate(requests): + input_text = data.get("data") + if input_text is None: + input_text = data.get("body") + if isinstance(input_text, (bytes, bytearray)): + input_text = input_text.decode("utf-8") + + max_length = self.setup_config["max_length"] + logger.info("Received text: '%s'", input_text) + + inputs = self.tokenizer.encode_plus( + input_text, + max_length=int(max_length), + pad_to_max_length=True, + add_special_tokens=True, + return_tensors="pt", + ) + + input_ids = inputs["input_ids"].to(self.device) + attention_mask = inputs["attention_mask"].to(self.device) + # making a batch out of the recieved requests + # attention masks are passed for cases where input tokens are padded. + if input_ids.shape is not None: + if input_ids_batch is None: + input_ids_batch = input_ids + attention_mask_batch = attention_mask + else: + input_ids_batch = torch.cat((input_ids_batch, input_ids), 0) + attention_mask_batch = torch.cat( + (attention_mask_batch, attention_mask), 0 + ) + return (input_ids_batch, attention_mask_batch) + + def inference(self, input_batch): + """Predict the class (or classes) of the received text using the + serialized transformers checkpoint. + Args: + input_batch (list): List of Text Tensors from the pre-process function is passed here + Returns: + list : It returns a list of the predicted value for the input text + """ + (input_ids_batch, _) = input_batch + inferences = [] + input_ids_batch = input_ids_batch.to(self.device) + model_input_dict = {} + model_input_dict["input_ids"]=input_ids_batch + # outputs = self.model.generate( + # input_ids_batch, do_sample=True, max_length=50, top_p=0.95, top_k=60 + # ) + # for i, _ in enumerate(outputs): + # inferences.append( + # self.tokenizer.decode(outputs[i], skip_special_tokens=True) + # ) + output = self.model(**model_input_dict) + print("************** here is the output",type(output)) + logger.info("Generated text: '%s'", inferences) + inference.append(output) + print("Generated text", inferences) + return inferences + + def postprocess(self, inference_output): + """Post Process Function converts the predicted response into Torchserve readable format. + Args: + inference_output (list): It contains the predicted response of the input text. + Returns: + (list): Returns a list of the Predictions and Explanations. + """ + return inference_output + + def handle(self, data, context): + start_time = time.time() + + self.context = context + metrics = self.context.metrics + + #run_pippy(self.initialize, context) + + is_profiler_enabled = os.environ.get("ENABLE_TORCH_PROFILER", None) + if is_profiler_enabled: + if PROFILER_AVAILABLE: + output, _ = self._infer_with_profiler(data=data) + else: + raise RuntimeError( + "Profiler is enabled but current version of torch does not support." + "Install torch>=1.8.1 to use profiler." 
+ ) + else: + if self._is_describe(): + output = [self.describe_handle()] + else: + data_preprocess = self.preprocess(data) + + if not self._is_explain(): + output = self.inference(data_preprocess) + output = self.postprocess(output) + else: + output = self.explain_handle(data_preprocess, data) + + stop_time = time.time() + metrics.add_time( + "HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms" + ) + return output \ No newline at end of file diff --git a/ts/model_service_worker.py b/ts/model_service_worker.py index 9b18fd6038..2996ae8d0a 100644 --- a/ts/model_service_worker.py +++ b/ts/model_service_worker.py @@ -16,6 +16,7 @@ from ts.metrics.metric_cache_yaml_impl import MetricsCacheYamlImpl from ts.model_loader import ModelLoaderFactory from ts.protocol.otf_message_handler import create_load_model_response, retrieve_msg +from pippy import run_pippy MAX_FAILURE_THRESHOLD = 5 SOCKET_ACCEPT_TIMEOUT = 30.0 @@ -218,7 +219,7 @@ def run_server(self): worker = TorchModelServiceWorker( sock_type, socket_name, host, port, metrics_config ) - worker.run_server() + run_pippy(worker.run_server()) if BENCHMARK: pr.disable() pr.dump_stats("/tmp/tsPythonProfile.prof") From f0cc44552de405e0ceaf8b7d95a915f57ad37e96 Mon Sep 17 00:00:00 2001 From: lxning Date: Thu, 23 Feb 2023 23:28:56 -0800 Subject: [PATCH 05/47] update archive modification --- .../java/org/pytorch/serve/archive/model/ModelArchive.java | 4 ++-- .../java/org/pytorch/serve/archive/model/ModelConfig.java | 1 + model-archiver/model_archiver/manifest_components/model.py | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java index 399aba49ca..647b2435b7 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java @@ -202,8 +202,8 @@ public ModelConfig getModelConfig() { try (Reader r = new InputStreamReader( Files.newInputStream(modelConfigFilePath), StandardCharsets.UTF_8)) { - - setModelConfig(yaml.load(r)); + + this.modelConfig = (ModelConfig) yaml.load(r); } catch (YAMLException | IOException e) { logger.error("Failed to parse " + modelConfigFilePath.toString(), e); } diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java index 44058c2275..e790693bd4 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java @@ -62,6 +62,7 @@ public void setGpuIds(ArrayList gpuIds) { this.gpuIds = gpuIds; } + public int getParallelLevel() { return parallelLevel; } diff --git a/model-archiver/model_archiver/manifest_components/model.py b/model-archiver/model_archiver/manifest_components/model.py index 5abcf9bf0a..d496bc3fca 100644 --- a/model-archiver/model_archiver/manifest_components/model.py +++ b/model-archiver/model_archiver/manifest_components/model.py @@ -30,9 +30,9 @@ def __init__(self, model_name, serialized_file, handler, model_file=None, model_ self.config_file = None if config_file: if sys.platform.startswith('win32') and config_file.find("\\") != -1: - self.config = config_file.split("\\")[-1] + self.config_file = config_file.split("\\")[-1] else: - self.config = config_file.split("/")[-1] + 
self.config_file = config_file.split("/")[-1] self.model_dict = self.__to_dict__() From a06598c70dddb51f29083232acb19b55c0a9c031 Mon Sep 17 00:00:00 2001 From: lxning Date: Mon, 27 Feb 2023 16:54:31 -0800 Subject: [PATCH 06/47] add model config file in mar file; frontend add torchrun and update connection b/w frontend and backend; update backend handler for pippy integration --- .../Huggingface_Largemodels/model-config.yaml | 5 ++ .../Huggingface_Largemodels/pippy_handler.py | 68 +++++++++++-------- .../serve/archive/model/ModelArchive.java | 23 ++----- .../serve/archive/model/ModelConfig.java | 4 -- .../serve/archive/utils/ArchiveUtils.java | 20 ++++++ .../pytorch/serve/wlm/WorkerLifeCycle.java | 30 ++++++-- .../org/pytorch/serve/wlm/WorkerThread.java | 51 +++++++------- ts/arg_parser.py | 28 ++++++++ ts/model_service_worker.py | 54 +++++++++++---- 9 files changed, 190 insertions(+), 93 deletions(-) create mode 100644 examples/Huggingface_Largemodels/model-config.yaml diff --git a/examples/Huggingface_Largemodels/model-config.yaml b/examples/Huggingface_Largemodels/model-config.yaml new file mode 100644 index 0000000000..297fb1cd4c --- /dev/null +++ b/examples/Huggingface_Largemodels/model-config.yaml @@ -0,0 +1,5 @@ +minWorkers: 1 +maxWorkers: 1 +maxBatchDelay: 100 +responseTimeout: 120 +parallelLevel: 4 diff --git a/examples/Huggingface_Largemodels/pippy_handler.py b/examples/Huggingface_Largemodels/pippy_handler.py index 20f4060eb8..340c9051a2 100644 --- a/examples/Huggingface_Largemodels/pippy_handler.py +++ b/examples/Huggingface_Largemodels/pippy_handler.py @@ -30,6 +30,7 @@ import requests from transformers import AutoFeatureExtractor, RegNetModel from transformers import OPTForCausalLM +import torch.distributed.rpc as rpc logger = logging.getLogger(__name__) @@ -42,6 +43,11 @@ "float64": torch.float64, } +schedules = { + 'FillDrain': PipelineDriverFillDrain, + '1F1B': PipelineDriver1F1B, + 'Interleaved1F1B': PipelineDriverInterleaved1F1B, +} class TransformersSeqClassifierHandler(BaseHandler, ABC): """ @@ -51,6 +57,11 @@ class TransformersSeqClassifierHandler(BaseHandler, ABC): def __init__(self): super(TransformersSeqClassifierHandler, self).__init__() self.initialized = False + self.local_rank = int(os.environ["LOCAL_RANK"]) + self.world_size = int(os.environ["WORLD_SIZE"]) + rpc.init_rpc(f"worker{self.local_rank}", + rank=self.local_rank, + world_size=self.world_size) def initialize(self, ctx): """In this initialize function, the BERT model is loaded and @@ -64,6 +75,8 @@ def initialize(self, ctx): # args = parser.parse_args() # args.world_size = 4 # args.gspmd = 1 + if self.local_rank != 0: + pass self.manifest = ctx.manifest properties = ctx.system_properties @@ -88,12 +101,12 @@ def initialize(self, ctx): logger.warning("Missing the setup_config.json file.") torch.manual_seed(42) + replicate = 0 + schedule = list(schedules.keys())[0] + MULTI_USE_PARAM_CONFIG = MultiUseParameterConfig.REPLICATE if replicate else MultiUseParameterConfig.TRANSMIT + print(f'REPLICATE config: {replicate} -> {MULTI_USE_PARAM_CONFIG}') + print("Using schedule:", schedule) - MULTI_USE_PARAM_CONFIG = MultiUseParameterConfig.REPLICATE if args.replicate else MultiUseParameterConfig.TRANSMIT - print(f'REPLICATE config: {args.replicate} -> {MULTI_USE_PARAM_CONFIG}') - print("Using schedule:", args.schedule) - - device = args.device self.model = BloomModel.from_pretrained( model_dir + "/model", use_cache=False) @@ -102,48 +115,39 @@ def initialize(self, ctx): # model = 
BloomModel.from_pretrained("bigscience/bloom-3b", use_cache=False) - model_config = model.config + model_config = self.model.config model_config.use_cache = False # don't output `past_key_values` - model.eval() - print(model.config) - print(f"model total number of params = {get_number_of_params(model) // 10 ** 6}M") - - number_of_workers = len(pp_ranks) - pippy.utils.exclude_master - print(f"number_of_workers = {number_of_workers}") + self.model.eval() + print(model_config) + print(f"model total number of params = {self.get_number_of_params(self.model) // 10 ** 6}M") - if args.auto_split == "threshold": - split_policy = split_on_size_threshold(490 * 1e6) - elif args.auto_split == "equal_size": - split_policy = split_into_equal_size(number_of_workers) - - all_worker_ranks = pp_ranks[pippy.utils.exclude_master:pippy.utils.exclude_master + number_of_workers] - chunks = args.chunks or len(all_worker_ranks) - bs = args.batch_size * chunks - seq_length = args.seq_length + split_policy = split_into_equal_size(1) + pp_ranks = [0,1,2,3] + all_worker_ranks = pp_ranks[pippy.utils.exclude_master:pippy.utils.exclude_master + 1] + chunks = 1 + bs = 1 * chunks + seq_length = 16 input_names = ['input_ids'] - sig = inspect.signature(model.forward) + sig = inspect.signature(self.model.forward) concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} print('Instantiating model Pipeline') model_init_start = time.time() - model_pipe = Pipe.from_tracing(model, MULTI_USE_PARAM_CONFIG, tracer=PiPPyHFTracer(), concrete_args=concrete_args, + model_pipe = Pipe.from_tracing(self.model, MULTI_USE_PARAM_CONFIG, tracer=PiPPyHFTracer(), concrete_args=concrete_args, output_loss_value_spec=None, split_policy=split_policy ) - model_pipe.defer_stage_init(args.device) + model_pipe.defer_stage_init(self.device + self.local_rank) pippy.utils.pp_group_barrier() - - if args.rank!=0: - return split_gm_children = list(model_pipe.split_gm.children()) - pipe_driver: PipelineDriverBase = schedules[args.schedule](model_pipe, chunks, - world_size=len(all_worker_ranks), + pipe_driver: PipelineDriverBase = schedules[schedule](model_pipe, chunks, + world_size=self.world_size, all_ranks=all_worker_ranks, ) @@ -152,6 +156,8 @@ def initialize(self, ctx): self.initialized = True + def get_number_of_params(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) def preprocess(self, requests): """Basic text preprocessing, based on the user's chocie of application mode. 
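
The `initialize` path above assumes the handler process was launched by torchrun and that only rank 0 acts as the PiPPy pipeline driver. A minimal sketch of that bootstrap follows; treating non-zero ranks as passive RPC workers is an assumption about the intended behaviour, since this revision of the handler simply executes `pass` on those ranks.

```python
# Sketch of the rank-aware bootstrap the pippy_handler above is built around.
# LOCAL_RANK and WORLD_SIZE are set by torchrun (wired up in WorkerLifeCycle in
# this series); the non-zero-rank behaviour shown here is an assumption.
import os

import torch.distributed.rpc as rpc

local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# Every rank joins the RPC group so the rank-0 driver can place pipeline stages.
rpc.init_rpc(f"worker{local_rank}", rank=local_rank, world_size=world_size)

if local_rank != 0:
    # Non-driver ranks only answer RPCs; a graceful shutdown blocks here until
    # the rank-0 driver finishes and calls rpc.shutdown() itself.
    rpc.shutdown()
else:
    # Rank 0 traces the model with PiPPy, builds the PipelineDriver, and serves
    # the inference requests forwarded by the TorchServe frontend.
    pass
```
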
@@ -219,7 +225,7 @@ def inference(self, input_batch): output = self.model(**model_input_dict) print("************** here is the output",type(output)) logger.info("Generated text: '%s'", inferences) - inference.append(output) + inferences.append(output) print("Generated text", inferences) return inferences @@ -233,6 +239,8 @@ def postprocess(self, inference_output): return inference_output def handle(self, data, context): + if self.local_rank != 0: + pass start_time = time.time() self.context = context diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java index 647b2435b7..6c802d8984 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java @@ -20,6 +20,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.yaml.snakeyaml.Yaml; +import org.yaml.snakeyaml.constructor.Constructor; import org.yaml.snakeyaml.error.YAMLException; public class ModelArchive { @@ -189,23 +190,13 @@ public void clean() { } } - public void setModelConfig(ModelConfig modelConfig) { - this.modelConfig = modelConfig; - } - public ModelConfig getModelConfig() { - if (this.modelConfig != null && manifest.getModel().getConfigFile() != null) { - Path modelConfigFilePath = - Paths.get(modelDir.getAbsolutePath(), manifest.getModel().getConfigFile()); - - Yaml yaml = new Yaml(); - try (Reader r = - new InputStreamReader( - Files.newInputStream(modelConfigFilePath), StandardCharsets.UTF_8)) { - - this.modelConfig = (ModelConfig) yaml.load(r); - } catch (YAMLException | IOException e) { - logger.error("Failed to parse " + modelConfigFilePath.toString(), e); + if (this.modelConfig == null && manifest.getModel().getConfigFile() != null) { + try { + File configFile = new File(modelDir.getAbsolutePath(), manifest.getModel().getConfigFile()); + this.modelConfig = ArchiveUtils.readYamlFile(configFile, ModelConfig.class); + } catch (InvalidModelException | IOException e) { + logger.error("Failed to parse model config file {}", manifest.getModel().getConfigFile()); } } return this.modelConfig; diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java index e790693bd4..6eebba8727 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java @@ -2,9 +2,6 @@ import java.util.ArrayList; -// import lombok.Data; - -// @Data public class ModelConfig { private int minWorkers = 1; private int maxWorkers = 1; @@ -62,7 +59,6 @@ public void setGpuIds(ArrayList gpuIds) { this.gpuIds = gpuIds; } - public int getParallelLevel() { return parallelLevel; } diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java index f752337608..6e8bb7c689 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java @@ -12,12 +12,18 @@ import java.nio.charset.StandardCharsets; import java.nio.file.FileAlreadyExistsException; import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import java.util.List; import 
java.util.regex.Pattern; import org.apache.commons.io.FileUtils; import org.pytorch.serve.archive.DownloadArchiveException; import org.pytorch.serve.archive.model.InvalidModelException; +import org.pytorch.serve.archive.model.ModelConfig; import org.pytorch.serve.archive.s3.HttpUtils; +import org.yaml.snakeyaml.Yaml; +import org.yaml.snakeyaml.constructor.Constructor; +import org.yaml.snakeyaml.error.YAMLException; public final class ArchiveUtils { @@ -39,6 +45,20 @@ public static T readFile(File file, Class type) } } + public static T readYamlFile(File file, Class type) + throws InvalidModelException, IOException { + //Yaml yaml = new Yaml(new Constructor(ModelConfig.class)); + Yaml yaml = new Yaml(new Constructor(type)); + try (Reader r = + new InputStreamReader( + Files.newInputStream(file.toPath()), StandardCharsets.UTF_8)) { + + return yaml.load(r); + } catch (YAMLException e) { + throw new InvalidModelException("Failed to parse model config yaml file.", e); + } + } + public static boolean validateURL(List allowedUrls, String url) throws InvalidArchiveURLException { boolean patternMatch = false; diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java index e4a0a7fd15..12345352e2 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java @@ -97,7 +97,10 @@ public void startWorker(int port) throws WorkerInitializationException, Interrup if (model.getParallelLevel() > 1) { attachRunner(argl, port); } - argl.add(EnvironmentUtils.getPythonRunTime(model)); + + if (model.getParallelLevel() == 1) { + argl.add(EnvironmentUtils.getPythonRunTime(model)); + } if (configManager.isCPULauncherEnabled()) { launcherArgs = configManager.getCPULauncherArgs(); @@ -130,6 +133,10 @@ public void startWorker(int port) throws WorkerInitializationException, Interrup argl.add("--metrics-config"); argl.add(configManager.getMetricsConfigPath()); + if (model.getParallelLevel() > 1) { + attachPippyArg(argl, port, model.getParallelLevel()); + } + String[] envp = EnvironmentUtils.getEnvString( workingDir.getAbsolutePath(), @@ -171,11 +178,24 @@ public void startWorker(int port) throws WorkerInitializationException, Interrup private void attachRunner(ArrayList argl, int port) { argl.add("torchrun"); - argl.add("--nnodes=1"); + //argl.add("--nnodes=1"); argl.add("--nproc_per_node=" + model.getParallelLevel()); - argl.add("--max_restarts=3"); - argl.add("--rdzv_backend=c10d"); - argl.add("--rdzv_endpoint=localhost:" + port); + argl.add("--max_restarts=0"); + argl.add("--master_addr=localhost"); + argl.add("--master_port=" + port); + //argl.add("--rdzv_backend=c10d"); + //argl.add("--rdzv_endpoint=localhost:" + port); + } + + private void attachPippyArg(ArrayList argl, int port, int parallelLevel) { + argl.add("--master_addr"); + argl.add("localhost"); + argl.add("--master_port"); + argl.add(Integer.toString(port)); + argl.add("--rank"); + argl.add("0"); + argl.add("--world_size"); + argl.add(Integer.toString(parallelLevel)); } public synchronized void terminateIOStreams() { diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java index a9aefbf06b..1308b332cb 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java @@ 
-186,13 +186,14 @@ public void run() { long wtStartTime = System.currentTimeMillis(); logger.info("Flushing req. to backend at: " + wtStartTime); - for (int i = 0; backendChannel.size() > 0 && i < model.getParallelLevel(); i++) { + int repeats = req.getCommand() == WorkerCommands.LOAD ? model.getParallelLevel() : 1; + for (int i = 0; backendChannel.size() > 0 && i < repeats; i++) { backendChannel.get(i).writeAndFlush(req).sync(); } ModelWorkerResponse reply = null; long begin = System.currentTimeMillis(); - for (int i = 0; i < model.getParallelLevel(); i++) { + for (int i = 0; i < repeats; i++) { reply = replies.poll(responseTimeout, TimeUnit.SECONDS); } @@ -320,8 +321,8 @@ private void connect() throws WorkerInitializationException, InterruptedExceptio final int parallelLevel = model.getParallelLevel(); final CountDownLatch latch = new CountDownLatch(parallelLevel); final int responseBufferSize = configManager.getMaxResponseSize(); - for (int i = 0; i < parallelLevel; i++) { - try { + try { + for (int i = 0; i < parallelLevel; i++) { Connector connector = new Connector(port + i); Bootstrap b = new Bootstrap(); b.group(backendEventGroup) @@ -356,7 +357,6 @@ public void initChannel(Channel ch) { thread.interrupt(); } }); - backendChannel .get(i) .newSucceededFuture() @@ -365,15 +365,14 @@ public void initChannel(Channel ch) { future -> { // TODO: // use gpu, batch size in load model command - RequestInput input = - new RequestInput(UUID.randomUUID().toString()); - if (gpuId >= 0) { - input.addParameter( - new InputParameter( - "gpu", String.valueOf(gpuId))); - } - - if (latch.getCount() == parallelLevel) { + if (latch.getCount() == 1) { + RequestInput input = + new RequestInput(UUID.randomUUID().toString()); + if (gpuId >= 0) { + input.addParameter( + new InputParameter( + "gpu", String.valueOf(gpuId))); + } Job job = new RestJob( @@ -386,20 +385,20 @@ public void initChannel(Channel ch) { } latch.countDown(); }); + } - if (!latch.await(WORKER_TIMEOUT, TimeUnit.MINUTES)) { - throw new WorkerInitializationException( - "Worker failed to initialize within " + WORKER_TIMEOUT + " mins"); - } - running.set(true); - } catch (Throwable t) { - // https://github.com/netty/netty/issues/2597 - if (t instanceof IOException) { - throw new WorkerInitializationException("Failed to connect to worker.", t); - } - throw t; + if (!latch.await(WORKER_TIMEOUT, TimeUnit.MINUTES)) { + throw new WorkerInitializationException( + "Worker failed to initialize within " + WORKER_TIMEOUT + " mins"); } - } + running.set(true); + } catch (Throwable t) { + // https://github.com/netty/netty/issues/2597 + if (t instanceof IOException) { + throw new WorkerInitializationException("Failed to connect to worker.", t); + } + throw t; + } } public boolean isRunning() { diff --git a/ts/arg_parser.py b/ts/arg_parser.py index 0a1d0595e1..44822fe563 100644 --- a/ts/arg_parser.py +++ b/ts/arg_parser.py @@ -128,6 +128,34 @@ def model_service_worker_args(): help="Metrics configuration file", ) + parser.add_argument( + "--master_addr", + dest="master_addr", + type=str, + help="pippy master addr", + ) + + parser.add_argument( + "--master_port", + dest="master_port", + type=int, + help="pippy master port", + ) + + parser.add_argument( + "--rank", + dest="rank", + type=int, + help="pippy rank", + ) + + parser.add_argument( + "--world_size", + dest="world_Size", + type=int, + help="pippy world_size", + ) + return parser @staticmethod diff --git a/ts/model_service_worker.py b/ts/model_service_worker.py index 2996ae8d0a..9d3ef27050 100644 --- 
a/ts/model_service_worker.py +++ b/ts/model_service_worker.py @@ -19,10 +19,15 @@ from pippy import run_pippy MAX_FAILURE_THRESHOLD = 5 -SOCKET_ACCEPT_TIMEOUT = 30.0 +SOCKET_ACCEPT_TIMEOUT = 300.0 DEBUG = False BENCHMARK = os.getenv("TS_BENCHMARK") BENCHMARK = BENCHMARK in ["True", "true", "TRUE"] +LOCAL_RANK = int(os.environ['LOCAL_RANK']) +WORLD_SIZE = int(os.environ['WORLD_SIZE']) +WORLD_RANK = int(os.environ['RANK']) +import torch.distributed.rpc as rpc +rpc.init_rpc(f"worker{LOCAL_RANK}", rank=LOCAL_RANK, world_size=WORLD_SIZE) class TorchModelServiceWorker(object): @@ -43,15 +48,19 @@ def __init__( if s_type == "unix": if s_name is None: raise ValueError("Wrong arguments passed. No socket name given.") - self.sock_name, self.port = s_name, -1 + s_name_parts = s_name.rsplit('.', 1) + print("part0="+s_name_parts[0]) + print("part1="+s_name_parts[1]) + s_name_new = s_name_parts[0] + '.' + str(int(s_name_parts[1]) + WORLD_RANK) + self.sock_name, self.port = s_name_new, -1 try: - os.remove(s_name) + os.remove(s_name_new) except OSError as e: - if os.path.exists(s_name): + if os.path.exists(s_name_new): raise RuntimeError( - "socket already in use: {}.".format(s_name) + "socket already in use: {}.".format(s_name_new) ) from e - + elif s_type == "tcp": self.sock_name = host_addr if host_addr is not None else "127.0.0.1" if port_num is None: @@ -59,8 +68,9 @@ def __init__( self.port = port_num else: raise ValueError("Incomplete data provided") - - logging.info("Listening on port: %s", s_name) + + #logging.info("Listening on port: %s", s_name) + print("Listening on port: "+ self.sock_name) socket_family = socket.AF_INET if s_type == "tcp" else socket.AF_UNIX self.sock = socket.socket(socket_family, socket.SOCK_STREAM) self.metrics_cache = MetricsCacheYamlImpl(config_file_path=metrics_config) @@ -166,6 +176,8 @@ def run_server(self): Run the backend worker process and listen on a socket :return: """ + print("sock_name="+self.sock_name) + print("sock_port="+str(self.port)) if not DEBUG: self.sock.settimeout(SOCKET_ACCEPT_TIMEOUT) @@ -173,10 +185,17 @@ def run_server(self): if self.sock_type == "unix": self.sock.bind(self.sock_name) + print("binded") + print("self.sock_name="+self.sock_name) else: self.sock.bind((self.sock_name, int(self.port))) - self.sock.listen(1) + # self.sock.listen(1) + self.sock.listen(128) + + print("listened") + print("[PID]"+str(os.getpid())) + print("Torch worker started.") logging.info("[PID]%d", os.getpid()) logging.info("Torch worker started.") logging.info("Python runtime: %s", platform.python_version()) @@ -186,7 +205,8 @@ def run_server(self): # workaround error(35, 'Resource temporarily unavailable') on OSX cl_socket.setblocking(True) - logging.info("Connection accepted: %s.", cl_socket.getsockname()) + #logging.info("Connection accepted: %s.", cl_socket.getsockname()) + print("Connection accepted: "+ cl_socket.getsockname()) self.handle_connection(cl_socket) @@ -206,8 +226,15 @@ def run_server(self): socket_name = args.sock_name sock_type = args.sock_type host = args.host - port = args.port + port = args.port metrics_config = args.metrics_config + args.rank = WORLD_RANK + args.world_size = args.world_size + + + print("LOCAL_RANK="+str(LOCAL_RANK)) + print("WORLD_SIZE="+str(WORLD_SIZE)) + print("WORLD_RANK="+str(WORLD_RANK)) if BENCHMARK: import cProfile @@ -219,7 +246,10 @@ def run_server(self): worker = TorchModelServiceWorker( sock_type, socket_name, host, port, metrics_config ) - run_pippy(worker.run_server()) + + worker.run_server() + + 
#run_pippy(worker.run_server(), args) if BENCHMARK: pr.disable() pr.dump_stats("/tmp/tsPythonProfile.prof") From 8119bc0c3745620537afe0cb5b771036a82f4455 Mon Sep 17 00:00:00 2001 From: lxning Date: Mon, 27 Feb 2023 17:46:20 -0800 Subject: [PATCH 07/47] clean up parameters --- .../org/pytorch/serve/wlm/WorkerLifeCycle.java | 4 ---- ts/arg_parser.py | 14 -------------- ts/model_service_worker.py | 7 ------- 3 files changed, 25 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java index 12345352e2..e5925ebf9c 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java @@ -192,10 +192,6 @@ private void attachPippyArg(ArrayList argl, int port, int parallelLevel) argl.add("localhost"); argl.add("--master_port"); argl.add(Integer.toString(port)); - argl.add("--rank"); - argl.add("0"); - argl.add("--world_size"); - argl.add(Integer.toString(parallelLevel)); } public synchronized void terminateIOStreams() { diff --git a/ts/arg_parser.py b/ts/arg_parser.py index 44822fe563..0dd2bafd4e 100644 --- a/ts/arg_parser.py +++ b/ts/arg_parser.py @@ -142,20 +142,6 @@ def model_service_worker_args(): help="pippy master port", ) - parser.add_argument( - "--rank", - dest="rank", - type=int, - help="pippy rank", - ) - - parser.add_argument( - "--world_size", - dest="world_Size", - type=int, - help="pippy world_size", - ) - return parser @staticmethod diff --git a/ts/model_service_worker.py b/ts/model_service_worker.py index 9d3ef27050..418ce67a06 100644 --- a/ts/model_service_worker.py +++ b/ts/model_service_worker.py @@ -26,9 +26,6 @@ LOCAL_RANK = int(os.environ['LOCAL_RANK']) WORLD_SIZE = int(os.environ['WORLD_SIZE']) WORLD_RANK = int(os.environ['RANK']) -import torch.distributed.rpc as rpc -rpc.init_rpc(f"worker{LOCAL_RANK}", rank=LOCAL_RANK, world_size=WORLD_SIZE) - class TorchModelServiceWorker(object): """ @@ -228,9 +225,6 @@ def run_server(self): host = args.host port = args.port metrics_config = args.metrics_config - args.rank = WORLD_RANK - args.world_size = args.world_size - print("LOCAL_RANK="+str(LOCAL_RANK)) print("WORLD_SIZE="+str(WORLD_SIZE)) @@ -249,7 +243,6 @@ def run_server(self): worker.run_server() - #run_pippy(worker.run_server(), args) if BENCHMARK: pr.disable() pr.dump_stats("/tmp/tsPythonProfile.prof") From 568720fde886a00afabf4360860153583dd76769 Mon Sep 17 00:00:00 2001 From: Hamid Shojanazeri Date: Thu, 2 Mar 2023 05:17:13 +0000 Subject: [PATCH 08/47] fix the start command --- examples/Huggingface_Largemodels/Readme.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/Huggingface_Largemodels/Readme.md b/examples/Huggingface_Largemodels/Readme.md index fd3ef0f3cd..636e3b0885 100644 --- a/examples/Huggingface_Largemodels/Readme.md +++ b/examples/Huggingface_Largemodels/Readme.md @@ -64,7 +64,7 @@ mv bloom.mar model_store Update config.properties and start torchserve ```bash -torchserve --start --ncs --ts-config config.properties +torchserve --ncs --start --model-store model_store --models bloom.mar --ts-config config.properties ``` ### Step 5: Run inference From 2b571b73df978c1fe60673c7163583f9b3c291d7 Mon Sep 17 00:00:00 2001 From: Hamid Shojanazeri Date: Thu, 2 Mar 2023 05:17:39 +0000 Subject: [PATCH 09/47] updated as per Li's suggestion --- .../Huggingface_Largemodels/config.properties | 16 ++++++++-------- 1 file changed, 8 insertions(+), 
8 deletions(-) diff --git a/examples/Huggingface_Largemodels/config.properties b/examples/Huggingface_Largemodels/config.properties index f02628fde6..a01bfe2c28 100644 --- a/examples/Huggingface_Largemodels/config.properties +++ b/examples/Huggingface_Largemodels/config.properties @@ -1,10 +1,10 @@ -inference_address=http://0.0.0.0:8080 +nference_address=http://0.0.0.0:8080 management_address=http://0.0.0.0:8081 -metrics_address=http://0.0.0.0:8082 -enable_envvars_config=true +number_of_netty_threads=32 +job_queue_size=1000 +vmargs=-Xmx4g -XX:+ExitOnOutOfMemoryError -XX:+HeapDumpOnOutOfMemoryError +prefer_direct_buffer=True +default_response_timeout=300 +unregister_model_timeout=300 install_py_dep_per_model=true -number_of_gpu=1 -load_models=all -max_response_size=655350000 -default_response_timeout=6000 -model_store=/home/ubuntu/serve/examples/Huggingface_Largemodels/model_store +default_workers_per_model=1 From ebda4eb6beac077ee7143ec673f2e2bc91208df6 Mon Sep 17 00:00:00 2001 From: Hamid Shojanazeri Date: Fri, 3 Mar 2023 06:27:34 +0000 Subject: [PATCH 10/47] adding pippy all_compile --- .../Huggingface_Largemodels/pippy_handler.py | 41 +++++++++++-------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/examples/Huggingface_Largemodels/pippy_handler.py b/examples/Huggingface_Largemodels/pippy_handler.py index 340c9051a2..0eb6fbf2d6 100644 --- a/examples/Huggingface_Largemodels/pippy_handler.py +++ b/examples/Huggingface_Largemodels/pippy_handler.py @@ -76,7 +76,7 @@ def initialize(self, ctx): # args.world_size = 4 # args.gspmd = 1 if self.local_rank != 0: - pass + return self.manifest = ctx.manifest properties = ctx.system_properties @@ -110,6 +110,9 @@ def initialize(self, ctx): self.model = BloomModel.from_pretrained( model_dir + "/model", use_cache=False) + self.tokenizer = BloomTokenizerFast.from_pretrained( + model_dir + "/model", return_tensors="pt" + ) logger.info("********************* model loaded *************************", model_dir) @@ -119,12 +122,11 @@ def initialize(self, ctx): model_config.use_cache = False # don't output `past_key_values` self.model.eval() - print(model_config) - print(f"model total number of params = {self.get_number_of_params(self.model) // 10 ** 6}M") + split_policy = split_into_equal_size(1) pp_ranks = [0,1,2,3] - all_worker_ranks = pp_ranks[pippy.utils.exclude_master:pippy.utils.exclude_master + 1] + all_worker_ranks = list(range(self.world_size)) chunks = 1 bs = 1 * chunks seq_length = 16 @@ -136,28 +138,35 @@ def initialize(self, ctx): print('Instantiating model Pipeline') model_init_start = time.time() - model_pipe = Pipe.from_tracing(self.model, MULTI_USE_PARAM_CONFIG, tracer=PiPPyHFTracer(), concrete_args=concrete_args, - output_loss_value_spec=None, split_policy=split_policy - ) + pipe_driver = pippy.all_compile( + model, + num_ranks=self.world_size, + num_chunks=chunks, + schedule="FillDrain", + split_policy=split_policy, + tracer=PiPPyHFTracer(), + concrete_args=concrete_args, + ) + # model_pipe = Pipe.from_tracing(self.model, MULTI_USE_PARAM_CONFIG, tracer=PiPPyHFTracer(), concrete_args=concrete_args, + # output_loss_value_spec=None, split_policy=split_policy + # ) - model_pipe.defer_stage_init(self.device + self.local_rank) + # model_pipe.defer_stage_init(self.device + self.local_rank) - pippy.utils.pp_group_barrier() + # pippy.utils.pp_group_barrier() - split_gm_children = list(model_pipe.split_gm.children()) + # split_gm_children = list(model_pipe.split_gm.children()) - pipe_driver: PipelineDriverBase = 
schedules[schedule](model_pipe, chunks, - world_size=self.world_size, - all_ranks=all_worker_ranks, - ) + # pipe_driver: PipelineDriverBase = schedules["FillDrain"](model_pipe, chunks, + # world_size=self.world_size, + # all_ranks=all_worker_ranks, + # ) self.model = pipe_driver logger.info("Transformer model from path %s loaded successfully", model_dir) self.initialized = True - def get_number_of_params(model): - return sum(p.numel() for p in model.parameters() if p.requires_grad) def preprocess(self, requests): """Basic text preprocessing, based on the user's chocie of application mode. From eebd8a9772ac1f7a0fa6657788db5be9465c6ae6 Mon Sep 17 00:00:00 2001 From: Hamid Shojanazeri Date: Fri, 3 Mar 2023 06:29:57 +0000 Subject: [PATCH 11/47] adding pippy all_compile --- examples/Huggingface_Largemodels/pippy_handler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/Huggingface_Largemodels/pippy_handler.py b/examples/Huggingface_Largemodels/pippy_handler.py index 0eb6fbf2d6..6cc8f53c52 100644 --- a/examples/Huggingface_Largemodels/pippy_handler.py +++ b/examples/Huggingface_Largemodels/pippy_handler.py @@ -107,7 +107,7 @@ def initialize(self, ctx): print(f'REPLICATE config: {replicate} -> {MULTI_USE_PARAM_CONFIG}') print("Using schedule:", schedule) - self.model = BloomModel.from_pretrained( + model = BloomModel.from_pretrained( model_dir + "/model", use_cache=False) self.tokenizer = BloomTokenizerFast.from_pretrained( @@ -118,10 +118,10 @@ def initialize(self, ctx): # model = BloomModel.from_pretrained("bigscience/bloom-3b", use_cache=False) - model_config = self.model.config + model_config = model.config model_config.use_cache = False # don't output `past_key_values` - self.model.eval() + model.eval() split_policy = split_into_equal_size(1) From 77d7a07f73a2a030f337cc978c1f6242962eacfb Mon Sep 17 00:00:00 2001 From: Hamid Shojanazeri Date: Fri, 3 Mar 2023 06:31:49 +0000 Subject: [PATCH 12/47] adding pippy all_compile, fix inference --- examples/Huggingface_Largemodels/pippy_handler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/Huggingface_Largemodels/pippy_handler.py b/examples/Huggingface_Largemodels/pippy_handler.py index 6cc8f53c52..5ee645924f 100644 --- a/examples/Huggingface_Largemodels/pippy_handler.py +++ b/examples/Huggingface_Largemodels/pippy_handler.py @@ -231,7 +231,8 @@ def inference(self, input_batch): # inferences.append( # self.tokenizer.decode(outputs[i], skip_special_tokens=True) # ) - output = self.model(**model_input_dict) + if self.local_rank==0: + output = self.model(**model_input_dict) print("************** here is the output",type(output)) logger.info("Generated text: '%s'", inferences) inferences.append(output) From 049a715fa93afe5dd5a05fffa6486964d7a0f813 Mon Sep 17 00:00:00 2001 From: lxning Date: Thu, 2 Mar 2023 22:56:46 -0800 Subject: [PATCH 13/47] fix Reply queue is full --- .../pytorch/serve/archive/model/ModelArchive.java | 8 -------- .../java/org/pytorch/serve/wlm/WorkerLifeCycle.java | 12 +++++++----- .../java/org/pytorch/serve/wlm/WorkerThread.java | 2 +- 3 files changed, 8 insertions(+), 14 deletions(-) diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java index 6c802d8984..e07db5d4de 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java @@ 
-3,13 +3,8 @@ import java.io.File; import java.io.IOException; import java.io.InputStream; -import java.io.InputStreamReader; -import java.io.Reader; -import java.nio.charset.StandardCharsets; import java.nio.file.FileAlreadyExistsException; import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; import java.util.List; import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; @@ -19,9 +14,6 @@ import org.pytorch.serve.archive.utils.ZipUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.yaml.snakeyaml.Yaml; -import org.yaml.snakeyaml.constructor.Constructor; -import org.yaml.snakeyaml.error.YAMLException; public class ModelArchive { diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java index e5925ebf9c..53d3ffe441 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java @@ -177,14 +177,16 @@ public void startWorker(int port) throws WorkerInitializationException, Interrup } private void attachRunner(ArrayList argl, int port) { + argl.add("torchrun"); - //argl.add("--nnodes=1"); + argl.add("--nnodes=1"); argl.add("--nproc_per_node=" + model.getParallelLevel()); - argl.add("--max_restarts=0"); + argl.add("--max_restarts=3"); argl.add("--master_addr=localhost"); - argl.add("--master_port=" + port); - //argl.add("--rdzv_backend=c10d"); - //argl.add("--rdzv_endpoint=localhost:" + port); + argl.add("--master_port=" + port); + argl.add("--log_dir=/tmp/torchelastic_ts"); + argl.add("--rdzv_backend=c10d"); + argl.add("--rdzv_id=" + model.getModelName() + "_" + port); } private void attachPippyArg(ArrayList argl, int port, int parallelLevel) { diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java index 1308b332cb..21a2ccfd30 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java @@ -159,7 +159,7 @@ public WorkerThread( this.listener = listener; startTime = System.currentTimeMillis(); lifeCycle = new WorkerLifeCycle(configManager, model); - replies = new ArrayBlockingQueue<>(1); + replies = new ArrayBlockingQueue<>(model.getParallelLevel()); workerLoadTime = new Metric( getWorkerName(), From b3e4fd9c99b0ae831be20dd82b64655d6f74afd6 Mon Sep 17 00:00:00 2001 From: Hamid Shojanazeri Date: Fri, 3 Mar 2023 17:45:44 -0800 Subject: [PATCH 14/47] adding device mapping --- .../Huggingface_Largemodels/pippy_handler.py | 35 ++++++++++++++----- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/examples/Huggingface_Largemodels/pippy_handler.py b/examples/Huggingface_Largemodels/pippy_handler.py index 5ee645924f..06bfc85f36 100644 --- a/examples/Huggingface_Largemodels/pippy_handler.py +++ b/examples/Huggingface_Largemodels/pippy_handler.py @@ -59,9 +59,27 @@ def __init__(self): self.initialized = False self.local_rank = int(os.environ["LOCAL_RANK"]) self.world_size = int(os.environ["WORLD_SIZE"]) + + options = rpc.TensorPipeRpcBackendOptions( + num_worker_threads=512, + rpc_timeout=1800, + _transports=None, + ) + + + # if args.cuda: + # n_devs = self.world_size + # # n_devs = 4 + + # if n_devs > 0: + # dev_id = self.local_rank % n_devs + # for i in range(4): + # options.set_device_map(f"worker{i}", {dev_id: i % 
n_devs}) + rpc.init_rpc(f"worker{self.local_rank}", rank=self.local_rank, world_size=self.world_size) + # rpc_backend_options=options) def initialize(self, ctx): """In this initialize function, the BERT model is loaded and @@ -82,11 +100,11 @@ def initialize(self, ctx): properties = ctx.system_properties model_dir = properties.get("model_dir") - self.device = torch.device( - "cuda:" + str(properties.get("gpu_id")) - if torch.cuda.is_available() and properties.get("gpu_id") is not None - else "cpu" - ) + # self.device = torch.device( + # "cuda:" + str(properties.get("gpu_id")) + # if torch.cuda.is_available() and properties.get("gpu_id") is not None + # else "cpu" + # ) # Loading the model and tokenizer from checkpoint and config files based on the user's choice of mode # further setup config can be added. with zipfile.ZipFile(model_dir + "/model.zip", "r") as zip_ref: @@ -133,12 +151,12 @@ def initialize(self, ctx): input_names = ['input_ids'] - sig = inspect.signature(self.model.forward) + sig = inspect.signature(model.forward) concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} print('Instantiating model Pipeline') model_init_start = time.time() - pipe_driver = pippy.all_compile( + pipe_driver, stage_mode = pippy.all_compile( model, num_ranks=self.world_size, num_chunks=chunks, @@ -233,6 +251,7 @@ def inference(self, input_batch): # ) if self.local_rank==0: output = self.model(**model_input_dict) + # rpc.shutdown() print("************** here is the output",type(output)) logger.info("Generated text: '%s'", inferences) inferences.append(output) @@ -283,4 +302,4 @@ def handle(self, data, context): metrics.add_time( "HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms" ) - return output \ No newline at end of file + return output From c1e9f9f89f77974d98de6f597baf842858f6f0a2 Mon Sep 17 00:00:00 2001 From: lxning Date: Sat, 4 Mar 2023 13:50:53 -0800 Subject: [PATCH 15/47] update device map --- .../Huggingface_Largemodels/pippy_handler.py | 43 ++++++++++--------- ts/model_service_worker.py | 22 +++++----- 2 files changed, 34 insertions(+), 31 deletions(-) diff --git a/examples/Huggingface_Largemodels/pippy_handler.py b/examples/Huggingface_Largemodels/pippy_handler.py index 06bfc85f36..fe748c9b9e 100644 --- a/examples/Huggingface_Largemodels/pippy_handler.py +++ b/examples/Huggingface_Largemodels/pippy_handler.py @@ -61,25 +61,28 @@ def __init__(self): self.world_size = int(os.environ["WORLD_SIZE"]) options = rpc.TensorPipeRpcBackendOptions( - num_worker_threads=512, - rpc_timeout=1800, - _transports=None, - ) + num_worker_threads=512, + rpc_timeout=1800 + # transports=None, + ) # if args.cuda: - # n_devs = self.world_size - # # n_devs = 4 - - # if n_devs > 0: - # dev_id = self.local_rank % n_devs - # for i in range(4): - # options.set_device_map(f"worker{i}", {dev_id: i % n_devs}) + n_devs = torch.cuda.device_count() + dev_id = self.local_rank % n_devs + for i in range (self.world_size): + options.set_device_map(f"worker{i}", {dev_id: i % n_devs}) + + self.device = f"cuda:{dev_id}" + print( + f"rank = {self.local_rank} pid/device = " + f"{os.getpid()}/{self.device}" + ) rpc.init_rpc(f"worker{self.local_rank}", rank=self.local_rank, - world_size=self.world_size) - # rpc_backend_options=options) + world_size=self.world_size, + rpc_backend_options=options) def initialize(self, ctx): """In this initialize function, the BERT model is loaded and @@ -157,13 +160,13 @@ def initialize(self, ctx): print('Instantiating model Pipeline') 
model_init_start = time.time() pipe_driver, stage_mode = pippy.all_compile( - model, - num_ranks=self.world_size, - num_chunks=chunks, - schedule="FillDrain", - split_policy=split_policy, - tracer=PiPPyHFTracer(), - concrete_args=concrete_args, + model, + num_ranks=self.world_size, + num_chunks=chunks, + schedule="FillDrain", + split_policy=split_policy, + tracer=PiPPyHFTracer(), + concrete_args=concrete_args, ) # model_pipe = Pipe.from_tracing(self.model, MULTI_USE_PARAM_CONFIG, tracer=PiPPyHFTracer(), concrete_args=concrete_args, # output_loss_value_spec=None, split_policy=split_policy diff --git a/ts/model_service_worker.py b/ts/model_service_worker.py index 418ce67a06..c0f9696b21 100644 --- a/ts/model_service_worker.py +++ b/ts/model_service_worker.py @@ -46,8 +46,8 @@ def __init__( if s_name is None: raise ValueError("Wrong arguments passed. No socket name given.") s_name_parts = s_name.rsplit('.', 1) - print("part0="+s_name_parts[0]) - print("part1="+s_name_parts[1]) + print(f"part0={s_name_parts[0]}, part1={s_name_parts[1]}, pid={str(os.getpid())}") + # print("part1="+s_name_parts[1]) s_name_new = s_name_parts[0] + '.' + str(int(s_name_parts[1]) + WORLD_RANK) self.sock_name, self.port = s_name_new, -1 try: @@ -173,8 +173,8 @@ def run_server(self): Run the backend worker process and listen on a socket :return: """ - print("sock_name="+self.sock_name) - print("sock_port="+str(self.port)) + print(f"sock_name={self.sock_name}, sock_port={str(self.port)}") + # print("sock_port="+str(self.port)) if not DEBUG: self.sock.settimeout(SOCKET_ACCEPT_TIMEOUT) @@ -182,16 +182,16 @@ def run_server(self): if self.sock_type == "unix": self.sock.bind(self.sock_name) - print("binded") - print("self.sock_name="+self.sock_name) + print(f"binded, self.sock_name={self.sock_name}") + # print("self.sock_name="+self.sock_name) else: self.sock.bind((self.sock_name, int(self.port))) # self.sock.listen(1) self.sock.listen(128) - print("listened") - print("[PID]"+str(os.getpid())) + print(f"listened, pid={str(os.getpid())}, LOCAL_RANK={str(LOCAL_RANK)}") + #print("[PID]"+str(os.getpid())) print("Torch worker started.") logging.info("[PID]%d", os.getpid()) logging.info("Torch worker started.") @@ -226,9 +226,9 @@ def run_server(self): port = args.port metrics_config = args.metrics_config - print("LOCAL_RANK="+str(LOCAL_RANK)) - print("WORLD_SIZE="+str(WORLD_SIZE)) - print("WORLD_RANK="+str(WORLD_RANK)) + print(f"LOCAL_RANK={str(LOCAL_RANK)}, WORLD_SIZE={str(WORLD_SIZE)}, WORLD_RANK={str(WORLD_RANK)}") + # print("WORLD_SIZE="+str(WORLD_SIZE)) + # print("WORLD_RANK="+str(WORLD_RANK)) if BENCHMARK: import cProfile From 5d81bd4166c4c45fd140b56ffcf7c48c88160c9c Mon Sep 17 00:00:00 2001 From: lxning Date: Sat, 4 Mar 2023 23:05:23 -0800 Subject: [PATCH 16/47] support torchelastic log --- .../pytorch/serve/wlm/WorkerLifeCycle.java | 32 +++++++++++++------ ts/model_service_worker.py | 8 +---- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java index 53d3ffe441..795a4041e2 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java @@ -9,6 +9,9 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + import 
org.pytorch.serve.metrics.Metric; import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.Connector; @@ -114,7 +117,6 @@ public void startWorker(int port) throws WorkerInitializationException, Interrup argl.add("--ninstances"); argl.add(String.valueOf(this.numWorker)); argl.add("--instance_idx"); - // instance_idx is 0-indexed argl.add(String.valueOf(this.currNumRunningWorkers)); } @@ -177,15 +179,14 @@ public void startWorker(int port) throws WorkerInitializationException, Interrup } private void attachRunner(ArrayList argl, int port) { - + System.setProperty("LOGLEVEL", "INFO"); argl.add("torchrun"); argl.add("--nnodes=1"); argl.add("--nproc_per_node=" + model.getParallelLevel()); argl.add("--max_restarts=3"); - argl.add("--master_addr=localhost"); - argl.add("--master_port=" + port); argl.add("--log_dir=/tmp/torchelastic_ts"); argl.add("--rdzv_backend=c10d"); + argl.add("--rdzv_endpoint=localhost:" + port); argl.add("--rdzv_id=" + model.getModelName() + "_" + port); } @@ -240,6 +241,12 @@ private synchronized void setPort(int port) { } private static final class ReaderThread extends Thread { + private static final Pattern METRIC_PATTERN = Pattern.compile( + "^(INFO > )?(\\[METRICS])(.*)"); + private static final Pattern WORKER_START_PATTERN = Pattern.compile( + "^(INFO > )?(Torch worker started.)$"); + private static final Pattern WORKER_PID_PATTERN = Pattern.compile( + "^(INFO > )?(\\[PID])(\\d+)$"); private InputStream is; private boolean error; @@ -269,8 +276,11 @@ public void run() { if (result == null) { break; } - if (result.startsWith("[METRICS]")) { - Metric parsedMetric = Metric.parse(result.substring("[METRICS]".length())); + + Matcher matcher = METRIC_PATTERN.matcher(result); + if (matcher.matches()) { + logger.info("result={}, pattern={}", result, matcher.group(2)); + Metric parsedMetric = Metric.parse(matcher.group(3)); if (parsedMetric != null) { loggerModelMetrics.info(parsedMetric.toString()); } else { @@ -279,10 +289,14 @@ public void run() { continue; } - if ("Torch worker started.".equals(result)) { + matcher = WORKER_START_PATTERN.matcher(result); + if (matcher.matches()) { lifeCycle.setSuccess(true); - } else if (result.startsWith("[PID]")) { - lifeCycle.setPid(Integer.parseInt(result.substring("[PID]".length()))); + } else { + matcher = WORKER_PID_PATTERN.matcher(result); + if (matcher.matches()) { + lifeCycle.setPid(Integer.parseInt(matcher.group(3))); + } } if (error) { loggerModelOutput.warn(result); diff --git a/ts/model_service_worker.py b/ts/model_service_worker.py index c0f9696b21..639fcbd1bb 100644 --- a/ts/model_service_worker.py +++ b/ts/model_service_worker.py @@ -47,7 +47,6 @@ def __init__( raise ValueError("Wrong arguments passed. No socket name given.") s_name_parts = s_name.rsplit('.', 1) print(f"part0={s_name_parts[0]}, part1={s_name_parts[1]}, pid={str(os.getpid())}") - # print("part1="+s_name_parts[1]) s_name_new = s_name_parts[0] + '.' 
+ str(int(s_name_parts[1]) + WORLD_RANK) self.sock_name, self.port = s_name_new, -1 try: @@ -174,7 +173,6 @@ def run_server(self): :return: """ print(f"sock_name={self.sock_name}, sock_port={str(self.port)}") - # print("sock_port="+str(self.port)) if not DEBUG: self.sock.settimeout(SOCKET_ACCEPT_TIMEOUT) @@ -183,7 +181,6 @@ def run_server(self): if self.sock_type == "unix": self.sock.bind(self.sock_name) print(f"binded, self.sock_name={self.sock_name}") - # print("self.sock_name="+self.sock_name) else: self.sock.bind((self.sock_name, int(self.port))) @@ -191,8 +188,7 @@ def run_server(self): self.sock.listen(128) print(f"listened, pid={str(os.getpid())}, LOCAL_RANK={str(LOCAL_RANK)}") - #print("[PID]"+str(os.getpid())) - print("Torch worker started.") + #print("Torch worker started.") logging.info("[PID]%d", os.getpid()) logging.info("Torch worker started.") logging.info("Python runtime: %s", platform.python_version()) @@ -227,8 +223,6 @@ def run_server(self): metrics_config = args.metrics_config print(f"LOCAL_RANK={str(LOCAL_RANK)}, WORLD_SIZE={str(WORLD_SIZE)}, WORLD_RANK={str(WORLD_RANK)}") - # print("WORLD_SIZE="+str(WORLD_SIZE)) - # print("WORLD_RANK="+str(WORLD_RANK)) if BENCHMARK: import cProfile From 3aff4ff2aa8315f7bf02280b7688d9616a8bdc43 Mon Sep 17 00:00:00 2001 From: lxning Date: Sun, 5 Mar 2023 19:14:08 -0800 Subject: [PATCH 17/47] add torchrun config --- examples/Huggingface_Largemodels/model-config.yaml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/examples/Huggingface_Largemodels/model-config.yaml b/examples/Huggingface_Largemodels/model-config.yaml index 297fb1cd4c..7e33b6a3b8 100644 --- a/examples/Huggingface_Largemodels/model-config.yaml +++ b/examples/Huggingface_Largemodels/model-config.yaml @@ -3,3 +3,11 @@ maxWorkers: 1 maxBatchDelay: 100 responseTimeout: 120 parallelLevel: 4 +parallelType: pp # pp: pipeline parallel; tp: tensor parallel; tp+pp + +torchrun: + logLevel: INFO + +pippy: + rpc_timeout: 1800 + From afc0ed05b28d06941f7cca97ff1848d04c7f8c0e Mon Sep 17 00:00:00 2001 From: lxning Date: Thu, 9 Mar 2023 00:29:56 -0800 Subject: [PATCH 18/47] fixed initialize and unit test --- .../Huggingface_Largemodels/pippy_handler.py | 25 ++--- .../serve/archive/model/ModelArchive.java | 7 +- .../serve/archive/model/ModelConfig.java | 72 +++++++++++-- .../serve/archive/utils/ArchiveUtils.java | 9 +- .../java/org/pytorch/serve/ModelServer.java | 54 +++++++--- .../serve/grpcimpl/ManagementImpl.java | 21 ++-- .../serve/http/HttpRequestHandlerChain.java | 3 +- .../rest/ApiDescriptionRequestHandler.java | 4 +- .../api/rest/InferenceRequestHandler.java | 4 +- .../api/rest/ManagementRequestHandler.java | 9 +- .../rest/PrometheusMetricsRequestHandler.java | 4 +- .../serve/snapshot/SnapshotManager.java | 6 +- .../java/org/pytorch/serve/util/ApiUtils.java | 25 +++-- .../java/org/pytorch/serve/wlm/Model.java | 64 ++++++++--- .../org/pytorch/serve/wlm/ModelManager.java | 101 +++++++++++++----- .../pytorch/serve/wlm/WorkLoadManager.java | 17 ++- .../pytorch/serve/wlm/WorkerLifeCycle.java | 24 ++--- .../org/pytorch/serve/wlm/WorkerThread.java | 95 ++++++++-------- .../http/WorkflowInferenceRequestHandler.java | 68 ++++++------ .../api/http/WorkflowMgmtRequestHandler.java | 30 +++--- ts/model_service_worker.py | 20 ++-- 21 files changed, 424 insertions(+), 238 deletions(-) diff --git a/examples/Huggingface_Largemodels/pippy_handler.py b/examples/Huggingface_Largemodels/pippy_handler.py index fe748c9b9e..b726e4a8a8 100644 --- a/examples/Huggingface_Largemodels/pippy_handler.py 
+++ b/examples/Huggingface_Largemodels/pippy_handler.py @@ -4,20 +4,14 @@ import zipfile from abc import ABC -import torch -import transformers -from transformers import BloomForCausalLM, BloomTokenizerFast - -from ts.torch_handler.base_handler import BaseHandler import argparse import inspect import logging import os import time -import torch import pippy.fx -from pippy import run_pippy +#from pippy import run_pippy from pippy.IR import MultiUseParameterConfig, Pipe from pippy.PipelineDriver import PipelineDriverFillDrain, PipelineDriver1F1B, PipelineDriverInterleaved1F1B, \ PipelineDriverBase @@ -32,11 +26,16 @@ from transformers import OPTForCausalLM import torch.distributed.rpc as rpc +import torch +import transformers +from transformers import BloomForCausalLM, BloomTokenizerFast + +from ts.torch_handler.base_handler import BaseHandler logger = logging.getLogger(__name__) logger.info("Transformers version %s", transformers.__version__) - +PIPPY_VERBOSITY = os.environ.get("PIPPY_VERBOSITY", "DEBUG") TORCH_DTYPES = { "float16": torch.float16, "float32": torch.float32, @@ -65,12 +64,14 @@ def __init__(self): rpc_timeout=1800 # transports=None, ) - + # if args.cuda: n_devs = torch.cuda.device_count() + print(f"n_devs={n_devs}") dev_id = self.local_rank % n_devs for i in range (self.world_size): + print(f"worker{i}, {dev_id}: {i % n_devs}") options.set_device_map(f"worker{i}", {dev_id: i % n_devs}) self.device = f"cuda:{dev_id}" @@ -96,8 +97,8 @@ def initialize(self, ctx): # args = parser.parse_args() # args.world_size = 4 # args.gspmd = 1 - if self.local_rank != 0: - return + #if self.local_rank != 0: + # return self.manifest = ctx.manifest properties = ctx.system_properties @@ -145,7 +146,7 @@ def initialize(self, ctx): model.eval() - split_policy = split_into_equal_size(1) + split_policy = split_into_equal_size(self.world_size) pp_ranks = [0,1,2,3] all_worker_ranks = list(range(self.world_size)) chunks = 1 diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java index e07db5d4de..dc77315180 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java @@ -185,10 +185,13 @@ public void clean() { public ModelConfig getModelConfig() { if (this.modelConfig == null && manifest.getModel().getConfigFile() != null) { try { - File configFile = new File(modelDir.getAbsolutePath(), manifest.getModel().getConfigFile()); + File configFile = + new File(modelDir.getAbsolutePath(), manifest.getModel().getConfigFile()); this.modelConfig = ArchiveUtils.readYamlFile(configFile, ModelConfig.class); } catch (InvalidModelException | IOException e) { - logger.error("Failed to parse model config file {}", manifest.getModel().getConfigFile()); + logger.error( + "Failed to parse model config file {}", + manifest.getModel().getConfigFile()); } } return this.modelConfig; diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java index 6eebba8727..feb3a06d6d 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java @@ -3,13 +3,15 @@ import java.util.ArrayList; public class ModelConfig { - private int minWorkers = 1; - private int maxWorkers = 1; - 
private int batchSize = 1; - private int maxBatchDelay = 100; - private int responseTimeout = 120; - private ArrayList gpuIds; + private int minWorkers; + private int maxWorkers; + private int batchSize; + private int maxBatchDelay; + private int responseTimeout; + private CoreType coreType = CoreType.NONE; + private ArrayList coreIds; private int parallelLevel = 1; + private ParallelType parallelType = ParallelType.NONE; public int getMinWorkers() { return minWorkers; @@ -51,12 +53,12 @@ public void setResponseTimeout(int responseTimeout) { this.responseTimeout = responseTimeout; } - public ArrayList getGpuIds() { - return gpuIds; + public ArrayList getCoreIds() { + return coreIds; } - public void setGpuIds(ArrayList gpuIds) { - this.gpuIds = gpuIds; + public void setCoreIds(ArrayList coreIds) { + this.coreIds = coreIds; } public int getParallelLevel() { @@ -66,4 +68,54 @@ public int getParallelLevel() { public void setParallelLevel(int parallelLevel) { this.parallelLevel = parallelLevel; } + + public void setParallelType(String parallelType) { + this.parallelType = ParallelType.valueOf(parallelType); + } + + public ParallelType getParallelType() { + return parallelType; + } + + public void setCoreType(String coreType) { + this.coreType = CoreType.valueOf(coreType); + } + + public CoreType getCoreType() { + return coreType; + } + + public enum ParallelType { + NONE(""), + PP("pp"), + TP("tp"), + PPTP("pptp"); + + private String type; + + ParallelType(String type) { + this.type = type; + } + + public String getParallelType() { + return type; + } + } + + public enum CoreType { + NONE(""), + CPU("cpu"), + GPU("gpu"), + NEURON("neuron"); + + private String type; + + CoreType(String type) { + this.type = type; + } + + public String getCoreType() { + return type; + } + } } diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java index 6e8bb7c689..c6c1c39104 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java @@ -12,14 +12,11 @@ import java.nio.charset.StandardCharsets; import java.nio.file.FileAlreadyExistsException; import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; import java.util.List; import java.util.regex.Pattern; import org.apache.commons.io.FileUtils; import org.pytorch.serve.archive.DownloadArchiveException; import org.pytorch.serve.archive.model.InvalidModelException; -import org.pytorch.serve.archive.model.ModelConfig; import org.pytorch.serve.archive.s3.HttpUtils; import org.yaml.snakeyaml.Yaml; import org.yaml.snakeyaml.constructor.Constructor; @@ -47,11 +44,11 @@ public static T readFile(File file, Class type) public static T readYamlFile(File file, Class type) throws InvalidModelException, IOException { - //Yaml yaml = new Yaml(new Constructor(ModelConfig.class)); + // Yaml yaml = new Yaml(new Constructor(ModelConfig.class)); Yaml yaml = new Yaml(new Constructor(type)); try (Reader r = - new InputStreamReader( - Files.newInputStream(file.toPath()), StandardCharsets.UTF_8)) { + new InputStreamReader( + Files.newInputStream(file.toPath()), StandardCharsets.UTF_8)) { return yaml.load(r); } catch (YAMLException e) { diff --git a/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java b/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java index 2ceee181d6..9a284febc2 100644 --- 
a/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java +++ b/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java @@ -52,21 +52,21 @@ import org.pytorch.serve.wlm.Model; import org.pytorch.serve.wlm.ModelManager; import org.pytorch.serve.wlm.WorkLoadManager; +import org.pytorch.serve.wlm.WorkerInitializationException; import org.pytorch.serve.workflow.WorkflowManager; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class ModelServer { + public static final int MAX_RCVBUF_SIZE = 4096; private Logger logger = LoggerFactory.getLogger(ModelServer.class); - private ServerGroups serverGroups; private Server inferencegRPCServer; private Server managementgRPCServer; private List futures = new ArrayList<>(2); private AtomicBoolean stopped = new AtomicBoolean(false); private ConfigManager configManager; - public static final int MAX_RCVBUF_SIZE = 4096; /** Creates a new {@code ModelServer} instance. */ public ModelServer(ConfigManager configManager) { @@ -192,26 +192,39 @@ private void initModelStore() throws InvalidSnapshotException, IOException { ModelArchive archive = modelManager.registerModel(file.getName(), defaultModelName); - modelManager.updateModel( - archive.getModelName(), - archive.getModelVersion(), + int minWorkers = configManager.getJsonIntValue( archive.getModelName(), archive.getModelVersion(), Model.MIN_WORKERS, - workers), + workers); + int maxWorkers = configManager.getJsonIntValue( archive.getModelName(), archive.getModelVersion(), Model.MAX_WORKERS, - workers), + workers); + if (archive.getModelConfig() != null) { + int marMinWorkers = archive.getModelConfig().getMinWorkers(); + int marMaxWorkers = archive.getModelConfig().getMaxWorkers(); + if (marMinWorkers > 0 && marMaxWorkers >= marMinWorkers) { + minWorkers = marMinWorkers; + maxWorkers = marMaxWorkers; + } + } + modelManager.updateModel( + archive.getModelName(), + archive.getModelVersion(), + minWorkers, + maxWorkers, true, false); startupModels.add(archive.getModelName()); } catch (ModelException | IOException | InterruptedException - | DownloadArchiveException e) { + | DownloadArchiveException + | WorkerInitializationException e) { logger.warn("Failed to load model: " + file.getAbsolutePath(), e); } } @@ -251,26 +264,39 @@ private void initModelStore() throws InvalidSnapshotException, IOException { false, false, false); - modelManager.updateModel( - archive.getModelName(), - archive.getModelVersion(), + int minWorkers = configManager.getJsonIntValue( archive.getModelName(), archive.getModelVersion(), Model.MIN_WORKERS, - workers), + workers); + int maxWorkers = configManager.getJsonIntValue( archive.getModelName(), archive.getModelVersion(), Model.MAX_WORKERS, - workers), + workers); + if (archive.getModelConfig() != null) { + int marMinWorkers = archive.getModelConfig().getMinWorkers(); + int marMaxWorkers = archive.getModelConfig().getMaxWorkers(); + if (marMinWorkers > 0 && marMaxWorkers >= marMinWorkers) { + minWorkers = marMinWorkers; + maxWorkers = marMaxWorkers; + } + } + modelManager.updateModel( + archive.getModelName(), + archive.getModelVersion(), + minWorkers, + maxWorkers, true, false); startupModels.add(archive.getModelName()); } catch (ModelException | IOException | InterruptedException - | DownloadArchiveException e) { + | DownloadArchiveException + | WorkerInitializationException e) { logger.warn("Failed to load model: " + url, e); } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/ManagementImpl.java 
b/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/ManagementImpl.java index a034f600e0..3ad18221dc 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/ManagementImpl.java +++ b/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/ManagementImpl.java @@ -27,12 +27,21 @@ import org.pytorch.serve.util.JsonUtils; import org.pytorch.serve.util.messages.RequestInput; import org.pytorch.serve.wlm.ModelManager; +import org.pytorch.serve.wlm.WorkerInitializationException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class ManagementImpl extends ManagementAPIsServiceImplBase { private static final Logger logger = LoggerFactory.getLogger(ManagementImpl.class); + public static void sendErrorResponse( + StreamObserver responseObserver, Status status, Exception e) { + responseObserver.onError( + status.withDescription(e.getMessage()) + .augmentDescription(e.getClass().getCanonicalName()) + .asRuntimeException()); + } + @Override public void describeModel( DescribeModelRequest request, StreamObserver responseObserver) { @@ -117,7 +126,7 @@ public void registerModel( sendStatusResponse(responseObserver, statusResponse); } catch (InternalServerException e) { sendException(responseObserver, e, null); - } catch (ExecutionException | InterruptedException e) { + } catch (ExecutionException | InterruptedException | WorkerInitializationException e) { sendException(responseObserver, e, "Error while creating workers"); } catch (ModelNotFoundException | ModelVersionNotFoundException e) { sendErrorResponse(responseObserver, Status.NOT_FOUND, e); @@ -156,7 +165,7 @@ public void scaleWorker( false, null); sendStatusResponse(responseObserver, statusResponse); - } catch (ExecutionException | InterruptedException e) { + } catch (ExecutionException | InterruptedException | WorkerInitializationException e) { sendException(responseObserver, e, "Error while creating workers"); } catch (ModelNotFoundException | ModelVersionNotFoundException e) { sendErrorResponse(responseObserver, Status.NOT_FOUND, e); @@ -241,14 +250,6 @@ private void sendErrorResponse( .asRuntimeException()); } - public static void sendErrorResponse( - StreamObserver responseObserver, Status status, Exception e) { - responseObserver.onError( - status.withDescription(e.getMessage()) - .augmentDescription(e.getClass().getCanonicalName()) - .asRuntimeException()); - } - private void sendStatusResponse( StreamObserver responseObserver, StatusResponse statusResponse) { int httpResponseStatusCode = statusResponse.getHttpResponseCode(); diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/HttpRequestHandlerChain.java b/frontend/server/src/main/java/org/pytorch/serve/http/HttpRequestHandlerChain.java index 8a381bcfca..7219c5b81b 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/HttpRequestHandlerChain.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/HttpRequestHandlerChain.java @@ -20,6 +20,7 @@ import org.pytorch.serve.servingsdk.impl.ModelServerResponse; import org.pytorch.serve.util.NettyUtils; import org.pytorch.serve.wlm.ModelManager; +import org.pytorch.serve.wlm.WorkerInitializationException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,7 +46,7 @@ public abstract void handleRequest( QueryStringDecoder decoder, String[] segments) throws ModelNotFoundException, ModelException, DownloadArchiveException, - WorkflowException; + WorkflowException, WorkerInitializationException; private void run( ModelServerEndpoint endpoint, diff --git 
a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java index 9b99a826dd..05422dafa8 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java @@ -12,6 +12,7 @@ import org.pytorch.serve.openapi.OpenApiUtils; import org.pytorch.serve.util.ConnectorType; import org.pytorch.serve.util.NettyUtils; +import org.pytorch.serve.wlm.WorkerInitializationException; public class ApiDescriptionRequestHandler extends HttpRequestHandlerChain { @@ -27,7 +28,8 @@ public void handleRequest( FullHttpRequest req, QueryStringDecoder decoder, String[] segments) - throws ModelException, DownloadArchiveException, WorkflowException { + throws ModelException, DownloadArchiveException, WorkflowException, + WorkerInitializationException { if (isApiDescription(segments)) { String path = decoder.path(); diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java index 10774a1206..308c796a0a 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java @@ -31,6 +31,7 @@ import org.pytorch.serve.util.messages.RequestInput; import org.pytorch.serve.wlm.Model; import org.pytorch.serve.wlm.ModelManager; +import org.pytorch.serve.wlm.WorkerInitializationException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -54,7 +55,8 @@ public void handleRequest( FullHttpRequest req, QueryStringDecoder decoder, String[] segments) - throws ModelException, DownloadArchiveException, WorkflowException { + throws ModelException, DownloadArchiveException, WorkflowException, + WorkerInitializationException { if (isInferenceReq(segments)) { if (endpointMap.getOrDefault(segments[1], null) != null) { handleCustomEndpoint(ctx, req, segments, decoder); diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java index 913708428f..29a6f156cf 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java @@ -38,6 +38,7 @@ import org.pytorch.serve.util.messages.WorkerCommands; import org.pytorch.serve.wlm.Model; import org.pytorch.serve.wlm.ModelManager; +import org.pytorch.serve.wlm.WorkerInitializationException; import org.pytorch.serve.wlm.WorkerThread; /** @@ -58,7 +59,8 @@ public void handleRequest( FullHttpRequest req, QueryStringDecoder decoder, String[] segments) - throws ModelException, DownloadArchiveException, WorkflowException { + throws ModelException, DownloadArchiveException, WorkflowException, + WorkerInitializationException { if (isManagementReq(segments)) { if (endpointMap.getOrDefault(segments[1], null) != null) { handleCustomEndpoint(ctx, req, segments, decoder); @@ -191,7 +193,7 @@ private KFV1ModelReadyResponse createKFV1ModelReadyResponse( private void handleRegisterModel( ChannelHandlerContext ctx, QueryStringDecoder decoder, FullHttpRequest req) - throws ModelException, 
DownloadArchiveException { + throws ModelException, DownloadArchiveException, WorkerInitializationException { RegisterModelRequest registerModelRequest = parseRequest(req, decoder); StatusResponse statusResponse; try { @@ -225,7 +227,8 @@ private void handleScaleModel( QueryStringDecoder decoder, String modelName, String modelVersion) - throws ModelNotFoundException, ModelVersionNotFoundException { + throws ModelNotFoundException, ModelVersionNotFoundException, + WorkerInitializationException { int minWorkers = NettyUtils.getIntParameter(decoder, "min_worker", 1); int maxWorkers = NettyUtils.getIntParameter(decoder, "max_worker", minWorkers); if (modelVersion == null) { diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/PrometheusMetricsRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/PrometheusMetricsRequestHandler.java index 41658e6909..9760babd46 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/PrometheusMetricsRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/PrometheusMetricsRequestHandler.java @@ -25,6 +25,7 @@ import org.pytorch.serve.archive.workflow.WorkflowException; import org.pytorch.serve.http.HttpRequestHandlerChain; import org.pytorch.serve.util.NettyUtils; +import org.pytorch.serve.wlm.WorkerInitializationException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -44,7 +45,8 @@ public void handleRequest( FullHttpRequest req, QueryStringDecoder decoder, String[] segments) - throws ModelException, DownloadArchiveException, WorkflowException { + throws ModelException, DownloadArchiveException, WorkflowException, + WorkerInitializationException { if (segments.length >= 2 && "metrics".equals(segments[1])) { ByteBuf resBuf = Unpooled.directBuffer(); List params = diff --git a/frontend/server/src/main/java/org/pytorch/serve/snapshot/SnapshotManager.java b/frontend/server/src/main/java/org/pytorch/serve/snapshot/SnapshotManager.java index c6d910d59d..6d0aca353f 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/snapshot/SnapshotManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/snapshot/SnapshotManager.java @@ -17,6 +17,7 @@ import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.wlm.Model; import org.pytorch.serve.wlm.ModelManager; +import org.pytorch.serve.wlm.WorkerInitializationException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -139,7 +140,10 @@ private void initModels(Snapshot snapshot) { } catch (IOException e) { logger.error("Error while retrieving snapshot details. Details: {}", e.getMessage()); - } catch (ModelException | InterruptedException | DownloadArchiveException e) { + } catch (ModelException + | InterruptedException + | DownloadArchiveException + | WorkerInitializationException e) { logger.error("Error while registering model. 
Details: {}", e.getMessage()); } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java b/frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java index c7adbca786..586fa0167c 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java @@ -35,6 +35,7 @@ import org.pytorch.serve.wlm.Model; import org.pytorch.serve.wlm.ModelManager; import org.pytorch.serve.wlm.ModelVersionedRefs; +import org.pytorch.serve.wlm.WorkerInitializationException; import org.pytorch.serve.wlm.WorkerState; import org.pytorch.serve.wlm.WorkerThread; @@ -108,7 +109,7 @@ public static String setDefault(String modelName, String newModelVersion) public static StatusResponse registerModel(RegisterModelRequest registerModelRequest) throws ModelException, InternalServerException, ExecutionException, - InterruptedException, DownloadArchiveException { + InterruptedException, DownloadArchiveException, WorkerInitializationException { String modelUrl = registerModelRequest.getModelUrl(); if (modelUrl == null) { throw new BadRequestException("Parameter url is required."); @@ -162,7 +163,7 @@ public static StatusResponse handleRegister( boolean isWorkflowModel, boolean s3SseKms) throws ModelException, ExecutionException, InterruptedException, - DownloadArchiveException { + DownloadArchiveException, WorkerInitializationException { ModelManager modelManager = ModelManager.getInstance(); final ModelArchive archive; @@ -188,7 +189,17 @@ public static StatusResponse handleRegister( } modelName = archive.getModelName(); - if (initialWorkers <= 0) { + int minWorkers = 0; + int maxWorkers = 0; + if (archive.getModelConfig() != null) { + int marMinWorkers = archive.getModelConfig().getMinWorkers(); + int marMaxWorkers = archive.getModelConfig().getMaxWorkers(); + if (marMinWorkers > 0 && marMaxWorkers >= marMinWorkers) { + minWorkers = marMinWorkers; + maxWorkers = marMaxWorkers; + } + } + if (initialWorkers <= 0 && minWorkers == 0) { final String msg = "Model \"" + modelName @@ -200,12 +211,14 @@ public static StatusResponse handleRegister( } return new StatusResponse(msg, HttpURLConnection.HTTP_OK); } + minWorkers = minWorkers > 0 ? minWorkers : initialWorkers; + maxWorkers = maxWorkers > 0 ? 
maxWorkers : initialWorkers; return ApiUtils.updateModelWorkers( modelName, archive.getModelVersion(), - initialWorkers, - initialWorkers, + minWorkers, + maxWorkers, isSync, true, f -> { @@ -223,7 +236,7 @@ public static StatusResponse updateModelWorkers( boolean isInit, final Function onError) throws ModelVersionNotFoundException, ModelNotFoundException, ExecutionException, - InterruptedException { + InterruptedException, WorkerInitializationException { ModelManager modelManager = ModelManager.getInstance(); if (maxWorkers < minWorkers) { diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java index b835e83df1..60dd3da4e0 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java @@ -14,6 +14,7 @@ import java.util.concurrent.locks.ReentrantLock; import org.apache.commons.io.FilenameUtils; import org.pytorch.serve.archive.model.ModelArchive; +import org.pytorch.serve.archive.model.ModelConfig; import org.pytorch.serve.job.Job; import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.messages.WorkerCommands; @@ -38,9 +39,17 @@ public class Model { private int minWorkers; private int maxWorkers; private int batchSize; + private int marBatchSize; private int maxBatchDelay; + private int marMaxBatchDelay; private int parallelLevel = 1; - private ArrayList gpuIds; + private ModelConfig.ParallelType parallelType = ModelConfig.ParallelType.NONE; + private ModelConfig.CoreType coreType = + ConfigManager.getInstance().getNumberOfGpu() > 0 + ? ModelConfig.CoreType.GPU + : ModelConfig.CoreType.CPU; + private ArrayList coreIds; + private int numCores; private ReentrantLock lock; private int responseTimeout; private ModelVersionName modelVersionName; @@ -56,17 +65,32 @@ public class Model { public Model(ModelArchive modelArchive, int queueSize) { this.modelArchive = modelArchive; if (modelArchive != null && modelArchive.getModelConfig() != null) { - minWorkers = modelArchive.getModelConfig().getMinWorkers(); - maxWorkers = modelArchive.getModelConfig().getMaxWorkers(); - batchSize = modelArchive.getModelConfig().getBatchSize(); - maxBatchDelay = modelArchive.getModelConfig().getMaxBatchDelay(); - responseTimeout = modelArchive.getModelConfig().getResponseTimeout(); - parallelLevel = modelArchive.getModelConfig().getParallelLevel(); - gpuIds = modelArchive.getModelConfig().getGpuIds(); + if (modelArchive.getModelConfig().getParallelLevel() > 1 + && modelArchive.getModelConfig().getParallelType() + != ModelConfig.ParallelType.NONE) { + parallelLevel = modelArchive.getModelConfig().getParallelLevel(); + parallelType = modelArchive.getModelConfig().getParallelType(); + } + if (modelArchive.getModelConfig().getCoreType() != ModelConfig.CoreType.NONE) { + coreType = + (modelArchive.getModelConfig().getCoreType() == ModelConfig.CoreType.GPU + && ConfigManager.getInstance().getNumberOfGpu() > 0) + ? ModelConfig.CoreType.GPU + : coreType; + } + coreIds = modelArchive.getModelConfig().getCoreIds(); } else { batchSize = 1; maxBatchDelay = 100; } + + if (ConfigManager.getInstance().getNumberOfGpu() > 0) { + numCores = + (coreIds != null && coreIds.size() > 0) + ? 
coreIds.size() + : ConfigManager.getInstance().getNumberOfGpu(); + } + jobsDb = new ConcurrentHashMap<>(); // Always have a queue for data jobsDb.putIfAbsent(DEFAULT_DATA_QUEUE, new LinkedBlockingDeque<>(queueSize)); @@ -269,19 +293,31 @@ public void setResponseTimeout(int responseTimeout) { this.responseTimeout = responseTimeout; } - public ArrayList getGpuIds() { - return this.gpuIds; + public ArrayList getCoreIds() { + return this.coreIds; } - public void setGpuIds(ArrayList gpuIds) { - Collections.copy(this.gpuIds, gpuIds); + public void setCoreIdsIds(ArrayList coreIds) { + Collections.copy(this.coreIds, coreIds); + } + + public int getParallelLevel() { + return this.parallelLevel; } public void setParallelLevel(int parallelLevel) { this.parallelLevel = parallelLevel; } - public int getParallelLevel() { - return this.parallelLevel; + public ModelConfig.ParallelType getParallelType() { + return this.parallelType; + } + + public ModelConfig.CoreType getCoreType() { + return this.coreType; + } + + public int getNumCores() { + return this.numCores; } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java index 7245241cab..63e26ef20b 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java @@ -21,6 +21,7 @@ import org.pytorch.serve.archive.DownloadArchiveException; import org.pytorch.serve.archive.model.Manifest; import org.pytorch.serve.archive.model.ModelArchive; +import org.pytorch.serve.archive.model.ModelConfig; import org.pytorch.serve.archive.model.ModelException; import org.pytorch.serve.archive.model.ModelNotFoundException; import org.pytorch.serve.archive.model.ModelVersionNotFoundException; @@ -52,10 +53,6 @@ private ModelManager(ConfigManager configManager, WorkLoadManager wlm) { this.startupModels = new HashSet<>(); } - public ScheduledExecutorService getScheduler() { - return scheduler; - } - public static void init(ConfigManager configManager, WorkLoadManager wlm) { modelManager = new ModelManager(configManager, wlm); } @@ -64,6 +61,10 @@ public static ModelManager getInstance() { return modelManager; } + public ScheduledExecutorService getScheduler() { + return scheduler; + } + public ModelArchive registerModel(String url, String defaultModelName) throws ModelException, IOException, InterruptedException, DownloadArchiveException { return registerModel( @@ -81,7 +82,8 @@ public ModelArchive registerModel(String url, String defaultModelName) } public void registerAndUpdateModel(String modelName, JsonObject modelInfo) - throws ModelException, IOException, InterruptedException, DownloadArchiveException { + throws ModelException, IOException, InterruptedException, DownloadArchiveException, + WorkerInitializationException { boolean defaultVersion = modelInfo.get(Model.DEFAULT_VERSION).getAsBoolean(); String url = modelInfo.get(Model.MAR_NAME).getAsString(); @@ -266,24 +268,44 @@ private Model createModel( boolean isWorkflowModel) { Model model = new Model(archive, configManager.getJobQueueSize()); - model.setBatchSize( - configManager.getJsonIntValue( - archive.getModelName(), - archive.getModelVersion(), - Model.BATCH_SIZE, - batchSize)); - model.setMaxBatchDelay( - configManager.getJsonIntValue( - archive.getModelName(), - archive.getModelVersion(), - Model.MAX_BATCH_DELAY, - maxBatchDelay)); - model.setResponseTimeout( - configManager.getJsonIntValue( - archive.getModelName(), 
- archive.getModelVersion(), - Model.RESPONSE_TIMEOUT, - responseTimeout)); + if (archive.getModelConfig() != null) { + int marBatchSize = archive.getModelConfig().getBatchSize(); + batchSize = + marBatchSize > 0 + ? marBatchSize + : configManager.getJsonIntValue( + archive.getModelName(), + archive.getModelVersion(), + Model.BATCH_SIZE, + batchSize); + } + model.setBatchSize(batchSize); + + if (archive.getModelConfig() != null) { + int marMaxBatchDelay = archive.getModelConfig().getMaxBatchDelay(); + maxBatchDelay = + marMaxBatchDelay > 0 + ? marMaxBatchDelay + : configManager.getJsonIntValue( + archive.getModelName(), + archive.getModelVersion(), + Model.MAX_BATCH_DELAY, + maxBatchDelay); + } + model.setMaxBatchDelay(maxBatchDelay); + + if (archive.getModelConfig() != null) { + int marResponseTimeout = archive.getModelConfig().getResponseTimeout(); + responseTimeout = + marResponseTimeout > 0 + ? marResponseTimeout + : configManager.getJsonIntValue( + archive.getModelName(), + archive.getModelVersion(), + Model.RESPONSE_TIMEOUT, + responseTimeout); + } + model.setResponseTimeout(responseTimeout); model.setWorkflowModel(isWorkflowModel); return model; @@ -379,7 +401,7 @@ public void setDefaultVersion(String modelName, String newModelVersion) private CompletableFuture updateModel( String modelName, String versionId, boolean isStartup) - throws ModelVersionNotFoundException { + throws ModelVersionNotFoundException, WorkerInitializationException { Model model = getVersionModel(modelName, versionId); return updateModel( modelName, @@ -397,14 +419,39 @@ public CompletableFuture updateModel( int maxWorkers, boolean isStartup, boolean isCleanUp) - throws ModelVersionNotFoundException { + throws ModelVersionNotFoundException, WorkerInitializationException { Model model = getVersionModel(modelName, versionId); if (model == null) { throw new ModelVersionNotFoundException( "Model version: " + versionId + " does not exist for model: " + modelName); } - + if (model.getParallelLevel() > 1 && model.getCoreType() == ModelConfig.CoreType.GPU) { + /** + * Current capacity check for LMI is based on single node. TODO: multiple nodes check + * will be based on --proc-per-node + numCores. + */ + int capacity = model.getNumCores() / model.getParallelLevel(); + if (capacity == 0) { + logger.error( + "there are no enough gpu devices to support this parallelLever: {}", + model.getParallelLevel()); + throw new WorkerInitializationException( + "No enough gpu devices for model:" + + modelName + + " parallelLevel:" + + model.getParallelLevel()); + } else { + minWorkers = minWorkers > capacity ? capacity : minWorkers; + maxWorkers = maxWorkers > capacity ? 
capacity : maxWorkers; + logger.info( + "model {} set minWorkers: {}, maxWorkers: {} for parallelLevel: {} ", + modelName, + minWorkers, + maxWorkers, + model.getParallelLevel()); + } + } model.setMinWorkers(minWorkers); model.setMaxWorkers(maxWorkers); logger.debug("updateModel: {}, count: {}", modelName, minWorkers); @@ -423,7 +470,7 @@ private Model getVersionModel(String modelName, String versionId) { public CompletableFuture updateModel( String modelName, String versionId, int minWorkers, int maxWorkers) - throws ModelVersionNotFoundException { + throws ModelVersionNotFoundException, WorkerInitializationException { return updateModel(modelName, versionId, minWorkers, maxWorkers, false, false); } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java index b630f28bd0..4758b8bc28 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java @@ -22,18 +22,15 @@ public class WorkLoadManager { + private static final Logger logger = LoggerFactory.getLogger(WorkLoadManager.class); private ExecutorService threadPool; - private ConcurrentHashMap> workers; - private ConfigManager configManager; private EventLoopGroup backendGroup; private AtomicInteger port; private AtomicInteger distributionPort; private AtomicInteger gpuCounter; - private static final Logger logger = LoggerFactory.getLogger(WorkLoadManager.class); - public WorkLoadManager(ConfigManager configManager, EventLoopGroup backendGroup) { this.configManager = configManager; this.backendGroup = backendGroup; @@ -195,8 +192,8 @@ private void addThreads( List threads, Model model, int count, CompletableFuture future) { WorkerStateListener listener = new WorkerStateListener(future, count); int maxGpu = configManager.getNumberOfGpu(); - if (maxGpu > 0 && model.getGpuIds() != null) { - maxGpu = model.getGpuIds().size(); + if (maxGpu > 0 && model.getCoreIds() != null) { + maxGpu = model.getCoreIds().size(); } int parallelGpuIdx = 0; for (int i = 0; i < count; ++i) { @@ -205,13 +202,13 @@ private void addThreads( if (maxGpu > 0) { if (model.getParallelLevel() > 1) { gpuId = - model.getGpuIds() != null - ? model.getGpuIds().get(parallelGpuIdx) + model.getCoreIds() != null + ? 
model.getCoreIds().get(parallelGpuIdx) : parallelGpuIdx; parallelGpuIdx += model.getParallelLevel(); } else { - if (model.getGpuIds() != null) { - gpuId = model.getGpuIds().get(parallelGpuIdx++ % maxGpu); + if (model.getCoreIds() != null) { + gpuId = model.getCoreIds().get(parallelGpuIdx++ % maxGpu); } else { gpuId = gpuCounter.accumulateAndGet( diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java index 795a4041e2..014f0dfafe 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java @@ -11,7 +11,6 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.regex.Matcher; import java.util.regex.Pattern; - import org.pytorch.serve.metrics.Metric; import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.Connector; @@ -189,7 +188,7 @@ private void attachRunner(ArrayList argl, int port) { argl.add("--rdzv_endpoint=localhost:" + port); argl.add("--rdzv_id=" + model.getModelName() + "_" + port); } - + private void attachPippyArg(ArrayList argl, int port, int parallelLevel) { argl.add("--master_addr"); argl.add("localhost"); @@ -241,21 +240,20 @@ private synchronized void setPort(int port) { } private static final class ReaderThread extends Thread { - private static final Pattern METRIC_PATTERN = Pattern.compile( - "^(INFO > )?(\\[METRICS])(.*)"); - private static final Pattern WORKER_START_PATTERN = Pattern.compile( - "^(INFO > )?(Torch worker started.)$"); - private static final Pattern WORKER_PID_PATTERN = Pattern.compile( - "^(INFO > )?(\\[PID])(\\d+)$"); - - private InputStream is; - private boolean error; - private WorkerLifeCycle lifeCycle; - private AtomicBoolean isRunning = new AtomicBoolean(true); + private static final Pattern METRIC_PATTERN = + Pattern.compile("^(INFO > )?(\\[METRICS])(.*)"); + private static final Pattern WORKER_START_PATTERN = + Pattern.compile("^(INFO > )?(Torch worker started.)$"); + private static final Pattern WORKER_PID_PATTERN = + Pattern.compile("^(INFO > )?(\\[PID])(\\d+)$"); private static final Logger loggerModelMetrics = LoggerFactory.getLogger(ConfigManager.MODEL_METRICS_LOGGER); private static final Logger loggerModelOutput = LoggerFactory.getLogger(ConfigManager.MODEL_LOGGER); + private InputStream is; + private boolean error; + private WorkerLifeCycle lifeCycle; + private AtomicBoolean isRunning = new AtomicBoolean(true); public ReaderThread(String name, InputStream is, boolean error, WorkerLifeCycle lifeCycle) { super(name + (error ? 
"-stderr" : "-stdout")); diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java index 21a2ccfd30..08493ad5fe 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java @@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import org.pytorch.serve.archive.model.ModelConfig; import org.pytorch.serve.job.Job; import org.pytorch.serve.job.RestJob; import org.pytorch.serve.metrics.Dimension; @@ -46,17 +47,13 @@ public class WorkerThread implements Runnable { LoggerFactory.getLogger(ConfigManager.MODEL_SERVER_METRICS_LOGGER); private static final Logger loggerTelemetryMetrics = LoggerFactory.getLogger(ConfigManager.MODEL_SERVER_TELEMETRY_LOGGER); - - private Metric workerLoadTime; - private static final int[] BACK_OFF = { 0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597 }; - private static final long WORKER_TIMEOUT = 2L; private static final ModelRequestEncoder ENCODER = new ModelRequestEncoder(ConfigManager.getInstance().getPreferDirectBuffer()); - + private Metric workerLoadTime; private ConfigManager configManager; private EventLoopGroup backendEventGroup; private int port; @@ -80,6 +77,34 @@ public class WorkerThread implements Runnable { private WorkerLifeCycle lifeCycle; + public WorkerThread( + ConfigManager configManager, + EventLoopGroup backendEventGroup, + int port, + int gpuId, + Model model, + BatchAggregator aggregator, + WorkerStateListener listener) { + this.workerId = String.valueOf(port); // Unique across all workers. + this.configManager = configManager; + this.backendEventGroup = backendEventGroup; + this.port = port; + this.model = model; + this.aggregator = aggregator; + this.gpuId = gpuId; + this.listener = listener; + startTime = System.currentTimeMillis(); + lifeCycle = new WorkerLifeCycle(configManager, model); + replies = new ArrayBlockingQueue<>(model.getParallelLevel()); + workerLoadTime = + new Metric( + getWorkerName(), + String.valueOf(System.currentTimeMillis()), + "ms", + ConfigManager.getInstance().getHostName(), + new Dimension("Level", "Host")); + } + public WorkerState getState() { return state; } @@ -141,34 +166,6 @@ public WorkerLifeCycle getLifeCycle() { return lifeCycle; } - public WorkerThread( - ConfigManager configManager, - EventLoopGroup backendEventGroup, - int port, - int gpuId, - Model model, - BatchAggregator aggregator, - WorkerStateListener listener) { - this.workerId = String.valueOf(port); // Unique across all workers. - this.configManager = configManager; - this.backendEventGroup = backendEventGroup; - this.port = port; - this.model = model; - this.aggregator = aggregator; - this.gpuId = gpuId; - this.listener = listener; - startTime = System.currentTimeMillis(); - lifeCycle = new WorkerLifeCycle(configManager, model); - replies = new ArrayBlockingQueue<>(model.getParallelLevel()); - workerLoadTime = - new Metric( - getWorkerName(), - String.valueOf(System.currentTimeMillis()), - "ms", - ConfigManager.getInstance().getHostName(), - new Dimension("Level", "Host")); - } - @Override public void run() { int responseTimeout = model.getResponseTimeout(); @@ -186,7 +183,14 @@ public void run() { long wtStartTime = System.currentTimeMillis(); logger.info("Flushing req. 
to backend at: " + wtStartTime); - int repeats = req.getCommand() == WorkerCommands.LOAD ? model.getParallelLevel() : 1; + int repeats = + (req.getCommand() == WorkerCommands.LOAD) + || (req.getCommand() == WorkerCommands.PREDICT + && model.getParallelLevel() > 1 + && model.getParallelType() + != ModelConfig.ParallelType.PP) + ? model.getParallelLevel() + : 1; for (int i = 0; backendChannel.size() > 0 && i < repeats; i++) { backendChannel.get(i).writeAndFlush(req).sync(); } @@ -321,7 +325,7 @@ private void connect() throws WorkerInitializationException, InterruptedExceptio final int parallelLevel = model.getParallelLevel(); final CountDownLatch latch = new CountDownLatch(parallelLevel); final int responseBufferSize = configManager.getMaxResponseSize(); - try { + try { for (int i = 0; i < parallelLevel; i++) { Connector connector = new Connector(port + i); Bootstrap b = new Bootstrap(); @@ -366,13 +370,14 @@ public void initChannel(Channel ch) { // TODO: // use gpu, batch size in load model command if (latch.getCount() == 1) { - RequestInput input = - new RequestInput(UUID.randomUUID().toString()); - if (gpuId >= 0) { - input.addParameter( - new InputParameter( - "gpu", String.valueOf(gpuId))); - } + RequestInput input = + new RequestInput( + UUID.randomUUID().toString()); + if (gpuId >= 0) { + input.addParameter( + new InputParameter( + "gpu", String.valueOf(gpuId))); + } Job job = new RestJob( @@ -392,13 +397,13 @@ public void initChannel(Channel ch) { "Worker failed to initialize within " + WORKER_TIMEOUT + " mins"); } running.set(true); - } catch (Throwable t) { - // https://github.com/netty/netty/issues/2597 + } catch (Throwable t) { + /* https://github.com/netty/netty/issues/2597 */ if (t instanceof IOException) { throw new WorkerInitializationException("Failed to connect to worker.", t); } throw t; - } + } } public boolean isRunning() { diff --git a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java index 1c67f4d2b8..042bef68e1 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java @@ -22,6 +22,7 @@ import org.pytorch.serve.util.NettyUtils; import org.pytorch.serve.util.messages.InputParameter; import org.pytorch.serve.util.messages.RequestInput; +import org.pytorch.serve.wlm.WorkerInitializationException; import org.pytorch.serve.workflow.WorkflowManager; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -39,13 +40,46 @@ public class WorkflowInferenceRequestHandler extends HttpRequestHandlerChain { /** Creates a new {@code WorkflowInferenceRequestHandler} instance. 
*/ public WorkflowInferenceRequestHandler() {} + private static RequestInput parseRequest(ChannelHandlerContext ctx, FullHttpRequest req) { + String requestId = NettyUtils.getRequestId(ctx.channel()); + RequestInput inputData = new RequestInput(requestId); + + CharSequence contentType = HttpUtil.getMimeType(req); + for (Map.Entry entry : req.headers().entries()) { + inputData.updateHeaders(entry.getKey(), entry.getValue()); + } + + if (HttpPostRequestDecoder.isMultipart(req) + || HttpHeaderValues.APPLICATION_X_WWW_FORM_URLENCODED.contentEqualsIgnoreCase( + contentType)) { + HttpDataFactory factory = + new DefaultHttpDataFactory(ConfigManager.getInstance().getMaxRequestSize()); + HttpPostRequestDecoder form = new HttpPostRequestDecoder(factory, req); + try { + while (form.hasNext()) { + inputData.addParameter(NettyUtils.getFormData(form.next())); + } + } catch (HttpPostRequestDecoder.EndOfDataDecoderException ignore) { + logger.trace("End of multipart items."); + } finally { + form.cleanFiles(); + form.destroy(); + } + } else { + byte[] content = NettyUtils.getBytes(req.content()); + inputData.addParameter(new InputParameter("body", content, contentType)); + } + return inputData; + } + @Override public void handleRequest( ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder, String[] segments) - throws ModelException, DownloadArchiveException, WorkflowException { + throws ModelException, DownloadArchiveException, WorkflowException, + WorkerInitializationException { if ("wfpredict".equalsIgnoreCase(segments[1])) { if (segments.length < 3) { throw new ResourceNotFoundException(); @@ -84,36 +118,4 @@ private void sendResponse(ChannelHandlerContext ctx, StatusResponse statusRespon } } } - - private static RequestInput parseRequest(ChannelHandlerContext ctx, FullHttpRequest req) { - String requestId = NettyUtils.getRequestId(ctx.channel()); - RequestInput inputData = new RequestInput(requestId); - - CharSequence contentType = HttpUtil.getMimeType(req); - for (Map.Entry entry : req.headers().entries()) { - inputData.updateHeaders(entry.getKey(), entry.getValue()); - } - - if (HttpPostRequestDecoder.isMultipart(req) - || HttpHeaderValues.APPLICATION_X_WWW_FORM_URLENCODED.contentEqualsIgnoreCase( - contentType)) { - HttpDataFactory factory = - new DefaultHttpDataFactory(ConfigManager.getInstance().getMaxRequestSize()); - HttpPostRequestDecoder form = new HttpPostRequestDecoder(factory, req); - try { - while (form.hasNext()) { - inputData.addParameter(NettyUtils.getFormData(form.next())); - } - } catch (HttpPostRequestDecoder.EndOfDataDecoderException ignore) { - logger.trace("End of multipart items."); - } finally { - form.cleanFiles(); - form.destroy(); - } - } else { - byte[] content = NettyUtils.getBytes(req.content()); - inputData.addParameter(new InputParameter("body", content, contentType)); - } - return inputData; - } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java index 3f3e599739..b50f5891b7 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java @@ -26,6 +26,7 @@ import org.pytorch.serve.http.StatusResponse; import org.pytorch.serve.util.JsonUtils; import org.pytorch.serve.util.NettyUtils; +import org.pytorch.serve.wlm.WorkerInitializationException; 
import org.pytorch.serve.workflow.WorkflowManager; import org.pytorch.serve.workflow.messages.DescribeWorkflowResponse; import org.pytorch.serve.workflow.messages.ListWorkflowResponse; @@ -41,13 +42,27 @@ public class WorkflowMgmtRequestHandler extends HttpRequestHandlerChain { /** Creates a new {@code WorkflowMgmtRequestHandler} instance. */ public WorkflowMgmtRequestHandler() {} + private static DescribeWorkflowResponse createWorkflowResponse( + String workflowName, WorkFlow workflow) { + DescribeWorkflowResponse response = new DescribeWorkflowResponse(); + response.setWorkflowName(workflowName); + response.setWorkflowUrl(workflow.getWorkflowArchive().getUrl()); + response.setBatchSize(workflow.getBatchSize()); + response.setMaxBatchDelay(workflow.getMaxBatchDelay()); + response.setMaxWorkers(workflow.getMaxWorkers()); + response.setMinWorkers(workflow.getMinWorkers()); + response.setWorkflowDag(workflow.getWorkflowDag()); + return response; + } + @Override public void handleRequest( ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder, String[] segments) - throws ModelException, DownloadArchiveException, WorkflowException { + throws ModelException, DownloadArchiveException, WorkflowException, + WorkerInitializationException { if (isManagementReq(segments)) { if (!"workflows".equals(segments[1])) { throw new ResourceNotFoundException(); @@ -194,17 +209,4 @@ private void sendResponse(ChannelHandlerContext ctx, StatusResponse statusRespon } } } - - private static DescribeWorkflowResponse createWorkflowResponse( - String workflowName, WorkFlow workflow) { - DescribeWorkflowResponse response = new DescribeWorkflowResponse(); - response.setWorkflowName(workflowName); - response.setWorkflowUrl(workflow.getWorkflowArchive().getUrl()); - response.setBatchSize(workflow.getBatchSize()); - response.setMaxBatchDelay(workflow.getMaxBatchDelay()); - response.setMaxWorkers(workflow.getMaxWorkers()); - response.setMinWorkers(workflow.getMinWorkers()); - response.setWorkflowDag(workflow.getWorkflowDag()); - return response; - } } diff --git a/ts/model_service_worker.py b/ts/model_service_worker.py index 639fcbd1bb..03b7ee792c 100644 --- a/ts/model_service_worker.py +++ b/ts/model_service_worker.py @@ -10,22 +10,20 @@ import platform import socket import sys -import uuid from ts.arg_parser import ArgParser from ts.metrics.metric_cache_yaml_impl import MetricsCacheYamlImpl from ts.model_loader import ModelLoaderFactory from ts.protocol.otf_message_handler import create_load_model_response, retrieve_msg -from pippy import run_pippy MAX_FAILURE_THRESHOLD = 5 SOCKET_ACCEPT_TIMEOUT = 300.0 DEBUG = False BENCHMARK = os.getenv("TS_BENCHMARK") BENCHMARK = BENCHMARK in ["True", "true", "TRUE"] -LOCAL_RANK = int(os.environ['LOCAL_RANK']) -WORLD_SIZE = int(os.environ['WORLD_SIZE']) -WORLD_RANK = int(os.environ['RANK']) +LOCAL_RANK = int(os.getenv('LOCAL_RANK', 0)) +WORLD_SIZE = int(os.getenv('WORLD_SIZE', 0)) +WORLD_RANK = int(os.getenv('RANK', 0)) class TorchModelServiceWorker(object): """ @@ -46,7 +44,7 @@ def __init__( if s_name is None: raise ValueError("Wrong arguments passed. No socket name given.") s_name_parts = s_name.rsplit('.', 1) - print(f"part0={s_name_parts[0]}, part1={s_name_parts[1]}, pid={str(os.getpid())}") + logging.info(f"part0={s_name_parts[0]}, part1={s_name_parts[1]}, pid={str(os.getpid())}") s_name_new = s_name_parts[0] + '.' 
+ str(int(s_name_parts[1]) + WORLD_RANK) self.sock_name, self.port = s_name_new, -1 try: @@ -65,8 +63,7 @@ def __init__( else: raise ValueError("Incomplete data provided") - #logging.info("Listening on port: %s", s_name) - print("Listening on port: "+ self.sock_name) + logging.info("Listening on port: %s", s_name) socket_family = socket.AF_INET if s_type == "tcp" else socket.AF_UNIX self.sock = socket.socket(socket_family, socket.SOCK_STREAM) self.metrics_cache = MetricsCacheYamlImpl(config_file_path=metrics_config) @@ -172,7 +169,6 @@ def run_server(self): Run the backend worker process and listen on a socket :return: """ - print(f"sock_name={self.sock_name}, sock_port={str(self.port)}") if not DEBUG: self.sock.settimeout(SOCKET_ACCEPT_TIMEOUT) @@ -180,15 +176,12 @@ def run_server(self): if self.sock_type == "unix": self.sock.bind(self.sock_name) - print(f"binded, self.sock_name={self.sock_name}") else: self.sock.bind((self.sock_name, int(self.port))) # self.sock.listen(1) self.sock.listen(128) - print(f"listened, pid={str(os.getpid())}, LOCAL_RANK={str(LOCAL_RANK)}") - #print("Torch worker started.") logging.info("[PID]%d", os.getpid()) logging.info("Torch worker started.") logging.info("Python runtime: %s", platform.python_version()) @@ -198,8 +191,7 @@ def run_server(self): # workaround error(35, 'Resource temporarily unavailable') on OSX cl_socket.setblocking(True) - #logging.info("Connection accepted: %s.", cl_socket.getsockname()) - print("Connection accepted: "+ cl_socket.getsockname()) + logging.info("Connection accepted: %s.", cl_socket.getsockname()) self.handle_connection(cl_socket) From e1ad19f3fe963a2f682ec7d127096e8f9a702799 Mon Sep 17 00:00:00 2001 From: lxning Date: Thu, 9 Mar 2023 08:51:46 -0800 Subject: [PATCH 19/47] update model-config.yaml --- examples/Huggingface_Largemodels/model-config.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/Huggingface_Largemodels/model-config.yaml b/examples/Huggingface_Largemodels/model-config.yaml index 7e33b6a3b8..e3dcf74c27 100644 --- a/examples/Huggingface_Largemodels/model-config.yaml +++ b/examples/Huggingface_Largemodels/model-config.yaml @@ -2,6 +2,8 @@ minWorkers: 1 maxWorkers: 1 maxBatchDelay: 100 responseTimeout: 120 +coreType: cpu # cpu, gpu, neuron +coreIds: [0,1,2.3] # core index for gpu, neuron parallelLevel: 4 parallelType: pp # pp: pipeline parallel; tp: tensor parallel; tp+pp From 0cd1fa87f1b89f5b8935c0c9ad97039273077571 Mon Sep 17 00:00:00 2001 From: lxning Date: Thu, 9 Mar 2023 15:29:06 -0800 Subject: [PATCH 20/47] update modelConfig --- .../serve/archive/model/ModelArchive.java | 2 +- .../serve/archive/model/ModelConfig.java | 46 +++++++++++++++---- .../serve/archive/utils/ArchiveUtils.java | 1 - .../java/org/pytorch/serve/wlm/Model.java | 4 +- .../pytorch/serve/wlm/WorkLoadManager.java | 14 ++---- 5 files changed, 43 insertions(+), 24 deletions(-) diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java index dc77315180..428845b366 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java @@ -191,7 +191,7 @@ public ModelConfig getModelConfig() { } catch (InvalidModelException | IOException e) { logger.error( "Failed to parse model config file {}", - manifest.getModel().getConfigFile()); + manifest.getModel().getConfigFile(), e); } } return this.modelConfig; 
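The ModelConfig changes below replace the gpuIds/coreIds list with `deviceIds` and add string-valued `deviceType` and `parallelMode` fields that are resolved to the `CoreType` and `ParallelType` enums. A minimal sketch of how a model-config.yaml maps onto those setters (the class name `ModelConfigSketch`, the Integer element type, and the lowercase enum strings "gpu" and "pp" are assumptions based on the example configs in this series, not part of the patch):

```java
import java.util.ArrayList;
import java.util.Arrays;
import org.pytorch.serve.archive.model.ModelConfig;

class ModelConfigSketch {
    // Roughly what the snakeyaml bean Constructor used by ArchiveUtils.readYamlFile
    // does when it populates ModelConfig from model-config.yaml via its setters.
    static ModelConfig fromYamlValues() {
        ModelConfig cfg = new ModelConfig();
        cfg.setDeviceType("gpu");   // stored as-is; coreType resolved via CoreType.get("gpu")
        cfg.setParallelMode("pp");  // stored as-is; parallelType resolved via ParallelType.get("pp")
        cfg.setParallelLevel(4);
        cfg.setDeviceIds(new ArrayList<>(Arrays.asList(0, 1, 2, 3)));
        return cfg;
    }
}
```
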
diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java index feb3a06d6d..4abdd2a960 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java @@ -1,6 +1,8 @@ package org.pytorch.serve.archive.model; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Optional; public class ModelConfig { private int minWorkers; @@ -8,9 +10,11 @@ public class ModelConfig { private int batchSize; private int maxBatchDelay; private int responseTimeout; + private String deviceType; private CoreType coreType = CoreType.NONE; - private ArrayList coreIds; + private ArrayList deviceIds; private int parallelLevel = 1; + private String parallelMode; private ParallelType parallelType = ParallelType.NONE; public int getMinWorkers() { @@ -53,12 +57,12 @@ public void setResponseTimeout(int responseTimeout) { this.responseTimeout = responseTimeout; } - public ArrayList getCoreIds() { - return coreIds; + public ArrayList getDeviceIds() { + return deviceIds; } - public void setCoreIds(ArrayList coreIds) { - this.coreIds = coreIds; + public void setDeviceIds(ArrayList deviceIds) { + this.deviceIds = deviceIds; } public int getParallelLevel() { @@ -69,16 +73,26 @@ public void setParallelLevel(int parallelLevel) { this.parallelLevel = parallelLevel; } - public void setParallelType(String parallelType) { - this.parallelType = ParallelType.valueOf(parallelType); + public void setParallelMode(String parallelMode) { + this.parallelMode = parallelMode; + this.parallelType = ParallelType.get(parallelMode).get(); + } + + public String getParallelMode() { + return this.parallelMode; } public ParallelType getParallelType() { - return parallelType; + return this.parallelType; + } + + public void setDeviceType(String deviceType) { + this.deviceType = deviceType; + this.coreType = CoreType.get(deviceType).get(); } - public void setCoreType(String coreType) { - this.coreType = CoreType.valueOf(coreType); + public String getDeviceType() { + return deviceType; } public CoreType getCoreType() { @@ -100,6 +114,12 @@ public enum ParallelType { public String getParallelType() { return type; } + + public static Optional get(String parallelType) { + return Arrays.stream(ParallelType.values()) + .filter(t -> t.type.equals(parallelType)) + .findFirst(); + } } public enum CoreType { @@ -117,5 +137,11 @@ public enum CoreType { public String getCoreType() { return type; } + + public static Optional get(String coreType) { + return Arrays.stream(CoreType.values()) + .filter(t -> t.type.equals(coreType)) + .findFirst(); + } } } diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java index c6c1c39104..5fd0817381 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java @@ -44,7 +44,6 @@ public static T readFile(File file, Class type) public static T readYamlFile(File file, Class type) throws InvalidModelException, IOException { - // Yaml yaml = new Yaml(new Constructor(ModelConfig.class)); Yaml yaml = new Yaml(new Constructor(type)); try (Reader r = new InputStreamReader( diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java 
b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java index 60dd3da4e0..d313ec15a0 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java @@ -39,9 +39,7 @@ public class Model { private int minWorkers; private int maxWorkers; private int batchSize; - private int marBatchSize; private int maxBatchDelay; - private int marMaxBatchDelay; private int parallelLevel = 1; private ModelConfig.ParallelType parallelType = ModelConfig.ParallelType.NONE; private ModelConfig.CoreType coreType = @@ -78,7 +76,7 @@ public Model(ModelArchive modelArchive, int queueSize) { ? ModelConfig.CoreType.GPU : coreType; } - coreIds = modelArchive.getModelConfig().getCoreIds(); + coreIds = modelArchive.getModelConfig().getDeviceIds(); } else { batchSize = 1; maxBatchDelay = 100; diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java index 4758b8bc28..088d84cef5 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java @@ -14,6 +14,8 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; + +import org.pytorch.serve.archive.model.ModelConfig; import org.pytorch.serve.snapshot.SnapshotManager; import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.OSUtils; @@ -191,20 +193,14 @@ public CompletableFuture modelChanged( private void addThreads( List threads, Model model, int count, CompletableFuture future) { WorkerStateListener listener = new WorkerStateListener(future, count); - int maxGpu = configManager.getNumberOfGpu(); - if (maxGpu > 0 && model.getCoreIds() != null) { - maxGpu = model.getCoreIds().size(); - } + int maxGpu = model.getNumCores(); int parallelGpuIdx = 0; for (int i = 0; i < count; ++i) { int gpuId = -1; - if (maxGpu > 0) { + if (maxGpu > 0 && model.getCoreType() == ModelConfig.CoreType.GPU) { if (model.getParallelLevel() > 1) { - gpuId = - model.getCoreIds() != null - ? 
model.getCoreIds().get(parallelGpuIdx) - : parallelGpuIdx; + gpuId = parallelGpuIdx; parallelGpuIdx += model.getParallelLevel(); } else { if (model.getCoreIds() != null) { From b78aedd8f59a23aea592ea537297643e2f7d3790 Mon Sep 17 00:00:00 2001 From: lxning Date: Thu, 9 Mar 2023 16:42:49 -0800 Subject: [PATCH 21/47] update archiver readme --- .../org/pytorch/serve/wlm/WorkerLifeCycle.java | 11 ----------- ts/arg_parser.py | 14 -------------- 2 files changed, 25 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java index 014f0dfafe..87d5b1d4f0 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java @@ -134,10 +134,6 @@ public void startWorker(int port) throws WorkerInitializationException, Interrup argl.add("--metrics-config"); argl.add(configManager.getMetricsConfigPath()); - if (model.getParallelLevel() > 1) { - attachPippyArg(argl, port, model.getParallelLevel()); - } - String[] envp = EnvironmentUtils.getEnvString( workingDir.getAbsolutePath(), @@ -189,13 +185,6 @@ private void attachRunner(ArrayList argl, int port) { argl.add("--rdzv_id=" + model.getModelName() + "_" + port); } - private void attachPippyArg(ArrayList argl, int port, int parallelLevel) { - argl.add("--master_addr"); - argl.add("localhost"); - argl.add("--master_port"); - argl.add(Integer.toString(port)); - } - public synchronized void terminateIOStreams() { if (errReader != null) { logger.warn("terminateIOStreams() threadName={}", errReader.getName()); diff --git a/ts/arg_parser.py b/ts/arg_parser.py index 0dd2bafd4e..0a1d0595e1 100644 --- a/ts/arg_parser.py +++ b/ts/arg_parser.py @@ -128,20 +128,6 @@ def model_service_worker_args(): help="Metrics configuration file", ) - parser.add_argument( - "--master_addr", - dest="master_addr", - type=str, - help="pippy master addr", - ) - - parser.add_argument( - "--master_port", - dest="master_port", - type=int, - help="pippy master port", - ) - return parser @staticmethod From 6ab3e071c850052587ca7a71e799c804d47dce9e Mon Sep 17 00:00:00 2001 From: lxning Date: Thu, 9 Mar 2023 16:43:15 -0800 Subject: [PATCH 22/47] update archiver readme --- model-archiver/README.md | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/model-archiver/README.md b/model-archiver/README.md index 73571f3205..596d97b63e 100644 --- a/model-archiver/README.md +++ b/model-archiver/README.md @@ -59,7 +59,7 @@ $ torch-model-archiver -h usage: torch-model-archiver [-h] --model-name MODEL_NAME --version MODEL_VERSION_NUMBER --model-file MODEL_FILE_PATH --serialized-file MODEL_SERIALIZED_PATH --handler HANDLER [--runtime {python,python3}] - [--export-path EXPORT_PATH] [-f] [--requirements-file] + [--export-path EXPORT_PATH] [-f] [--requirements-file] [--config-file] Model Archiver Tool @@ -113,6 +113,7 @@ optional arguments: -r, --requirements-file Path to requirements.txt file containing a list of model specific python packages to be installed by TorchServe for seamless model serving. + -c, --config-file Path to a model config yaml file. ``` ## Artifact Details @@ -152,6 +153,21 @@ e.g. if your custom handler custom_image_classifier.py is in /home/serve/example For more details refer [default handler documentation](../docs/default_handlers.md) or [custom handler documentation](../docs/custom_service.md) +### Config file + +A model config yaml file. 
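+Positive values set here (for example `minWorkers`, `maxWorkers`, `batchSize`, `maxBatchDelay` and `responseTimeout`) take precedence over the equivalent parameters supplied when the model is registered, and the `parallelLevel`, `parallelMode`, `deviceType` and `deviceIds` fields configure pipeline or tensor parallel execution for large models.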
For example: + +``` +minWorkers: 1 +maxWorkers: 1 +maxBatchDelay: 100 +responseTimeout: 120 +parallelLevel: 4 +parallelMode: "pp" # pp: pipeline parallel; tp: tensor parallel; pptp: pipeline+tensor parallel +deviceType: "gpu" #cpu, gpu, neuron +deviceIds: [0,1,2,3] +``` + ## Creating a Model Archive **1. Download the torch model archiver source** From 2b490dc756e6f17e072894a59ba3046ea193402e Mon Sep 17 00:00:00 2001 From: lxning Date: Sun, 12 Mar 2023 21:24:50 -0700 Subject: [PATCH 23/47] update ModelConfig and extend backend to parse model config yaml --- .../pippy_pp_handler.py | 309 +++++++++++++++++ .../pippy_pptp_handler.py | 312 ++++++++++++++++++ .../serve/archive/model/ModelArchive.java | 7 +- .../serve/archive/model/ModelConfig.java | 86 +++-- .../serve/archive/utils/ArchiveUtils.java | 14 + .../java/org/pytorch/serve/wlm/Model.java | 28 +- .../org/pytorch/serve/wlm/ModelManager.java | 2 +- .../pytorch/serve/wlm/WorkLoadManager.java | 3 +- ts/context.py | 2 + ts/service.py | 11 + ts/utils/util.py | 7 +- 11 files changed, 735 insertions(+), 46 deletions(-) create mode 100644 examples/Huggingface_Largemodels/pippy_pp_handler.py create mode 100644 examples/Huggingface_Largemodels/pippy_pptp_handler.py diff --git a/examples/Huggingface_Largemodels/pippy_pp_handler.py b/examples/Huggingface_Largemodels/pippy_pp_handler.py new file mode 100644 index 0000000000..b726e4a8a8 --- /dev/null +++ b/examples/Huggingface_Largemodels/pippy_pp_handler.py @@ -0,0 +1,309 @@ +import json +import logging +import os +import zipfile +from abc import ABC + +import argparse +import inspect +import logging +import os +import time + +import pippy.fx +#from pippy import run_pippy +from pippy.IR import MultiUseParameterConfig, Pipe +from pippy.PipelineDriver import PipelineDriverFillDrain, PipelineDriver1F1B, PipelineDriverInterleaved1F1B, \ + PipelineDriverBase +from pippy.hf import PiPPyHFTracer +from pippy.microbatch import TensorChunkSpec +from pippy import split_on_size_threshold, split_into_equal_size +from transformers import AutoModelForSeq2SeqLM +from transformers import OPTModel, BloomModel +from PIL import Image +import requests +from transformers import AutoFeatureExtractor, RegNetModel +from transformers import OPTForCausalLM +import torch.distributed.rpc as rpc + +import torch +import transformers +from transformers import BloomForCausalLM, BloomTokenizerFast + +from ts.torch_handler.base_handler import BaseHandler + +logger = logging.getLogger(__name__) +logger.info("Transformers version %s", transformers.__version__) + +PIPPY_VERBOSITY = os.environ.get("PIPPY_VERBOSITY", "DEBUG") +TORCH_DTYPES = { + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, +} + +schedules = { + 'FillDrain': PipelineDriverFillDrain, + '1F1B': PipelineDriver1F1B, + 'Interleaved1F1B': PipelineDriverInterleaved1F1B, +} + +class TransformersSeqClassifierHandler(BaseHandler, ABC): + """ + Transformers handler class for sequence, token classification and question answering. 
+ """ + + def __init__(self): + super(TransformersSeqClassifierHandler, self).__init__() + self.initialized = False + self.local_rank = int(os.environ["LOCAL_RANK"]) + self.world_size = int(os.environ["WORLD_SIZE"]) + + options = rpc.TensorPipeRpcBackendOptions( + num_worker_threads=512, + rpc_timeout=1800 + # transports=None, + ) + + + # if args.cuda: + n_devs = torch.cuda.device_count() + print(f"n_devs={n_devs}") + dev_id = self.local_rank % n_devs + for i in range (self.world_size): + print(f"worker{i}, {dev_id}: {i % n_devs}") + options.set_device_map(f"worker{i}", {dev_id: i % n_devs}) + + self.device = f"cuda:{dev_id}" + print( + f"rank = {self.local_rank} pid/device = " + f"{os.getpid()}/{self.device}" + ) + + rpc.init_rpc(f"worker{self.local_rank}", + rank=self.local_rank, + world_size=self.world_size, + rpc_backend_options=options) + + def initialize(self, ctx): + """In this initialize function, the BERT model is loaded and + the Layer Integrated Gradients Algorithm for Captum Explanations + is initialized here. + Args: + ctx (context): It is a JSON Object containing information + pertaining to the model artefacts parameters. + """ + # parser = argparse.ArgumentParser() + # args = parser.parse_args() + # args.world_size = 4 + # args.gspmd = 1 + #if self.local_rank != 0: + # return + + self.manifest = ctx.manifest + properties = ctx.system_properties + model_dir = properties.get("model_dir") + + # self.device = torch.device( + # "cuda:" + str(properties.get("gpu_id")) + # if torch.cuda.is_available() and properties.get("gpu_id") is not None + # else "cpu" + # ) + # Loading the model and tokenizer from checkpoint and config files based on the user's choice of mode + # further setup config can be added. + with zipfile.ZipFile(model_dir + "/model.zip", "r") as zip_ref: + zip_ref.extractall(model_dir + "/model") + + # read configs for the mode, model_name, etc. 
from setup_config.json + setup_config_path = os.path.join(model_dir, "setup_config.json") + if os.path.isfile(setup_config_path): + with open(setup_config_path) as setup_config_file: + self.setup_config = json.load(setup_config_file) + else: + logger.warning("Missing the setup_config.json file.") + + torch.manual_seed(42) + replicate = 0 + schedule = list(schedules.keys())[0] + MULTI_USE_PARAM_CONFIG = MultiUseParameterConfig.REPLICATE if replicate else MultiUseParameterConfig.TRANSMIT + print(f'REPLICATE config: {replicate} -> {MULTI_USE_PARAM_CONFIG}') + print("Using schedule:", schedule) + + model = BloomModel.from_pretrained( + model_dir + "/model", use_cache=False) + + self.tokenizer = BloomTokenizerFast.from_pretrained( + model_dir + "/model", return_tensors="pt" + ) + + logger.info("********************* model loaded *************************", model_dir) + + # model = BloomModel.from_pretrained("bigscience/bloom-3b", use_cache=False) + + model_config = model.config + + model_config.use_cache = False # don't output `past_key_values` + model.eval() + + + split_policy = split_into_equal_size(self.world_size) + pp_ranks = [0,1,2,3] + all_worker_ranks = list(range(self.world_size)) + chunks = 1 + bs = 1 * chunks + seq_length = 16 + + + input_names = ['input_ids'] + sig = inspect.signature(model.forward) + concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} + + print('Instantiating model Pipeline') + model_init_start = time.time() + pipe_driver, stage_mode = pippy.all_compile( + model, + num_ranks=self.world_size, + num_chunks=chunks, + schedule="FillDrain", + split_policy=split_policy, + tracer=PiPPyHFTracer(), + concrete_args=concrete_args, + ) + # model_pipe = Pipe.from_tracing(self.model, MULTI_USE_PARAM_CONFIG, tracer=PiPPyHFTracer(), concrete_args=concrete_args, + # output_loss_value_spec=None, split_policy=split_policy + # ) + + # model_pipe.defer_stage_init(self.device + self.local_rank) + + # pippy.utils.pp_group_barrier() + + # split_gm_children = list(model_pipe.split_gm.children()) + + # pipe_driver: PipelineDriverBase = schedules["FillDrain"](model_pipe, chunks, + # world_size=self.world_size, + # all_ranks=all_worker_ranks, + # ) + + self.model = pipe_driver + logger.info("Transformer model from path %s loaded successfully", model_dir) + + self.initialized = True + + + def preprocess(self, requests): + """Basic text preprocessing, based on the user's chocie of application mode. + Args: + requests (str): The Input data in the form of text is passed on to the preprocess + function. + Returns: + list : The preprocess function returns a list of Tensor for the size of the word tokens. + """ + input_ids_batch = None + attention_mask_batch = None + for idx, data in enumerate(requests): + input_text = data.get("data") + if input_text is None: + input_text = data.get("body") + if isinstance(input_text, (bytes, bytearray)): + input_text = input_text.decode("utf-8") + + max_length = self.setup_config["max_length"] + logger.info("Received text: '%s'", input_text) + + inputs = self.tokenizer.encode_plus( + input_text, + max_length=int(max_length), + pad_to_max_length=True, + add_special_tokens=True, + return_tensors="pt", + ) + + input_ids = inputs["input_ids"].to(self.device) + attention_mask = inputs["attention_mask"].to(self.device) + # making a batch out of the recieved requests + # attention masks are passed for cases where input tokens are padded. 
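+            # Every request is padded to the same max_length above, so the per-request
+            # tensors can be concatenated along dim 0 into a single batch that is later
+            # fed to the pipeline driver.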
+ if input_ids.shape is not None: + if input_ids_batch is None: + input_ids_batch = input_ids + attention_mask_batch = attention_mask + else: + input_ids_batch = torch.cat((input_ids_batch, input_ids), 0) + attention_mask_batch = torch.cat( + (attention_mask_batch, attention_mask), 0 + ) + return (input_ids_batch, attention_mask_batch) + + def inference(self, input_batch): + """Predict the class (or classes) of the received text using the + serialized transformers checkpoint. + Args: + input_batch (list): List of Text Tensors from the pre-process function is passed here + Returns: + list : It returns a list of the predicted value for the input text + """ + (input_ids_batch, _) = input_batch + inferences = [] + input_ids_batch = input_ids_batch.to(self.device) + model_input_dict = {} + model_input_dict["input_ids"]=input_ids_batch + # outputs = self.model.generate( + # input_ids_batch, do_sample=True, max_length=50, top_p=0.95, top_k=60 + # ) + # for i, _ in enumerate(outputs): + # inferences.append( + # self.tokenizer.decode(outputs[i], skip_special_tokens=True) + # ) + if self.local_rank==0: + output = self.model(**model_input_dict) + # rpc.shutdown() + print("************** here is the output",type(output)) + logger.info("Generated text: '%s'", inferences) + inferences.append(output) + print("Generated text", inferences) + return inferences + + def postprocess(self, inference_output): + """Post Process Function converts the predicted response into Torchserve readable format. + Args: + inference_output (list): It contains the predicted response of the input text. + Returns: + (list): Returns a list of the Predictions and Explanations. + """ + return inference_output + + def handle(self, data, context): + if self.local_rank != 0: + pass + start_time = time.time() + + self.context = context + metrics = self.context.metrics + + #run_pippy(self.initialize, context) + + is_profiler_enabled = os.environ.get("ENABLE_TORCH_PROFILER", None) + if is_profiler_enabled: + if PROFILER_AVAILABLE: + output, _ = self._infer_with_profiler(data=data) + else: + raise RuntimeError( + "Profiler is enabled but current version of torch does not support." + "Install torch>=1.8.1 to use profiler." 
+ ) + else: + if self._is_describe(): + output = [self.describe_handle()] + else: + data_preprocess = self.preprocess(data) + + if not self._is_explain(): + output = self.inference(data_preprocess) + output = self.postprocess(output) + else: + output = self.explain_handle(data_preprocess, data) + + stop_time = time.time() + metrics.add_time( + "HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms" + ) + return output diff --git a/examples/Huggingface_Largemodels/pippy_pptp_handler.py b/examples/Huggingface_Largemodels/pippy_pptp_handler.py new file mode 100644 index 0000000000..ee5bb05b83 --- /dev/null +++ b/examples/Huggingface_Largemodels/pippy_pptp_handler.py @@ -0,0 +1,312 @@ +import json +import zipfile +from abc import ABC + +import inspect +import logging +import os +import time + +import pippy +import pippy.fx +from pippy import run_pippy +from pippy.IR import pipe_split +from pippy.IR import MultiUseParameterConfig, Pipe +from pippy.PipelineDriver import PipelineDriverFillDrain, PipelineDriver1F1B, PipelineDriverInterleaved1F1B, \ + PipelineDriverBase +from pippy.hf import PiPPyHFTracer +from pippy.microbatch import TensorChunkSpec +from pippy import split_on_size_threshold, split_into_equal_size +from transformers import AutoModelForSeq2SeqLM +from transformers import OPTModel, BloomModel +from PIL import Image +import requests +from transformers import AutoFeatureExtractor, RegNetModel +from transformers import OPTForCausalLM +import torch.distributed.rpc as rpc + +import torch +import transformers +from transformers import BloomForCausalLM, BloomTokenizerFast + + +from pippy import run_pippy +from pippy.IR import pipe_split + +from ts.torch_handler.base_handler import BaseHandler + +logger = logging.getLogger(__name__) +logger.info("Transformers version %s", transformers.__version__) + +PIPPY_VERBOSITY = os.environ.get("PIPPY_VERBOSITY", "DEBUG") +TORCH_DTYPES = { + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, +} + +schedules = { + 'FillDrain': PipelineDriverFillDrain, + '1F1B': PipelineDriver1F1B, + 'Interleaved1F1B': PipelineDriverInterleaved1F1B, +} + +class TransformersSeqClassifierHandler(BaseHandler, ABC): + """ + Transformers handler class for sequence, token classification and question answering. + """ + + def __init__(self): + super(TransformersSeqClassifierHandler, self).__init__() + self.initialized = False + self.local_rank = int(os.environ["LOCAL_RANK"]) + self.world_size = int(os.environ["WORLD_SIZE"]) + + options = rpc.TensorPipeRpcBackendOptions( + num_worker_threads=512, + rpc_timeout=1800 + # transports=None, + ) + + + # if args.cuda: + n_devs = torch.cuda.device_count() + print(f"n_devs={n_devs}") + dev_id = self.local_rank % n_devs + for i in range (self.world_size): + print(f"worker{i}, {dev_id}: {i % n_devs}") + options.set_device_map(f"worker{i}", {dev_id: i % n_devs}) + + self.device = f"cuda:{dev_id}" + print( + f"rank = {self.local_rank} pid/device = " + f"{os.getpid()}/{self.device}" + ) + + rpc.init_rpc(f"worker{self.local_rank}", + rank=self.local_rank, + world_size=self.world_size, + rpc_backend_options=options) + + def initialize(self, ctx): + """In this initialize function, the BERT model is loaded and + the Layer Integrated Gradients Algorithm for Captum Explanations + is initialized here. + Args: + ctx (context): It is a JSON Object containing information + pertaining to the model artefacts parameters. 
+ """ + # parser = argparse.ArgumentParser() + # args = parser.parse_args() + # args.world_size = 4 + # args.gspmd = 1 + #if self.local_rank != 0: + # return + + self.manifest = ctx.manifest + properties = ctx.system_properties + model_dir = properties.get("model_dir") + + # self.device = torch.device( + # "cuda:" + str(properties.get("gpu_id")) + # if torch.cuda.is_available() and properties.get("gpu_id") is not None + # else "cpu" + # ) + # Loading the model and tokenizer from checkpoint and config files based on the user's choice of mode + # further setup config can be added. + with zipfile.ZipFile(model_dir + "/model.zip", "r") as zip_ref: + zip_ref.extractall(model_dir + "/model") + + # read configs for the mode, model_name, etc. from setup_config.json + setup_config_path = os.path.join(model_dir, "setup_config.json") + if os.path.isfile(setup_config_path): + with open(setup_config_path) as setup_config_file: + self.setup_config = json.load(setup_config_file) + else: + logger.warning("Missing the setup_config.json file.") + + torch.manual_seed(42) + replicate = 0 + schedule = list(schedules.keys())[0] + MULTI_USE_PARAM_CONFIG = MultiUseParameterConfig.REPLICATE if replicate else MultiUseParameterConfig.TRANSMIT + print(f'REPLICATE config: {replicate} -> {MULTI_USE_PARAM_CONFIG}') + print("Using schedule:", schedule) + + model = BloomModel.from_pretrained( + model_dir + "/model", use_cache=False) + + self.tokenizer = BloomTokenizerFast.from_pretrained( + model_dir + "/model", return_tensors="pt" + ) + + logger.info("********************* model loaded *************************", model_dir) + + # model = BloomModel.from_pretrained("bigscience/bloom-3b", use_cache=False) + + model_config = model.config + + model_config.use_cache = False # don't output `past_key_values` + model.eval() + + + split_policy = split_into_equal_size(self.world_size) + pp_ranks = [0,1,2,3] + all_worker_ranks = list(range(self.world_size)) + chunks = 1 + bs = 1 * chunks + seq_length = 16 + + + input_names = ['input_ids'] + sig = inspect.signature(model.forward) + concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} + + print('Instantiating model Pipeline') + model_init_start = time.time() + pipe_driver, stage_mode = pippy.all_compile( + model, + num_ranks=self.world_size, + num_chunks=chunks, + schedule="FillDrain", + split_policy=split_policy, + tracer=PiPPyHFTracer(), + concrete_args=concrete_args, + ) + # model_pipe = Pipe.from_tracing(self.model, MULTI_USE_PARAM_CONFIG, tracer=PiPPyHFTracer(), concrete_args=concrete_args, + # output_loss_value_spec=None, split_policy=split_policy + # ) + + # model_pipe.defer_stage_init(self.device + self.local_rank) + + # pippy.utils.pp_group_barrier() + + # split_gm_children = list(model_pipe.split_gm.children()) + + # pipe_driver: PipelineDriverBase = schedules["FillDrain"](model_pipe, chunks, + # world_size=self.world_size, + # all_ranks=all_worker_ranks, + # ) + + self.model = pipe_driver + logger.info("Transformer model from path %s loaded successfully", model_dir) + + self.initialized = True + + + def preprocess(self, requests): + """Basic text preprocessing, based on the user's chocie of application mode. + Args: + requests (str): The Input data in the form of text is passed on to the preprocess + function. + Returns: + list : The preprocess function returns a list of Tensor for the size of the word tokens. 
+ """ + input_ids_batch = None + attention_mask_batch = None + for idx, data in enumerate(requests): + input_text = data.get("data") + if input_text is None: + input_text = data.get("body") + if isinstance(input_text, (bytes, bytearray)): + input_text = input_text.decode("utf-8") + + max_length = self.setup_config["max_length"] + logger.info("Received text: '%s'", input_text) + + inputs = self.tokenizer.encode_plus( + input_text, + max_length=int(max_length), + pad_to_max_length=True, + add_special_tokens=True, + return_tensors="pt", + ) + + input_ids = inputs["input_ids"].to(self.device) + attention_mask = inputs["attention_mask"].to(self.device) + # making a batch out of the recieved requests + # attention masks are passed for cases where input tokens are padded. + if input_ids.shape is not None: + if input_ids_batch is None: + input_ids_batch = input_ids + attention_mask_batch = attention_mask + else: + input_ids_batch = torch.cat((input_ids_batch, input_ids), 0) + attention_mask_batch = torch.cat( + (attention_mask_batch, attention_mask), 0 + ) + return (input_ids_batch, attention_mask_batch) + + def inference(self, input_batch): + """Predict the class (or classes) of the received text using the + serialized transformers checkpoint. + Args: + input_batch (list): List of Text Tensors from the pre-process function is passed here + Returns: + list : It returns a list of the predicted value for the input text + """ + (input_ids_batch, _) = input_batch + inferences = [] + input_ids_batch = input_ids_batch.to(self.device) + model_input_dict = {} + model_input_dict["input_ids"]=input_ids_batch + # outputs = self.model.generate( + # input_ids_batch, do_sample=True, max_length=50, top_p=0.95, top_k=60 + # ) + # for i, _ in enumerate(outputs): + # inferences.append( + # self.tokenizer.decode(outputs[i], skip_special_tokens=True) + # ) + if self.local_rank==0: + output = self.model(**model_input_dict) + # rpc.shutdown() + print("************** here is the output",type(output)) + logger.info("Generated text: '%s'", inferences) + inferences.append(output) + print("Generated text", inferences) + return inferences + + def postprocess(self, inference_output): + """Post Process Function converts the predicted response into Torchserve readable format. + Args: + inference_output (list): It contains the predicted response of the input text. + Returns: + (list): Returns a list of the Predictions and Explanations. + """ + return inference_output + + def handle(self, data, context): + if self.local_rank != 0: + pass + start_time = time.time() + + self.context = context + metrics = self.context.metrics + + #run_pippy(self.initialize, context) + + is_profiler_enabled = os.environ.get("ENABLE_TORCH_PROFILER", None) + if is_profiler_enabled: + if PROFILER_AVAILABLE: + output, _ = self._infer_with_profiler(data=data) + else: + raise RuntimeError( + "Profiler is enabled but current version of torch does not support." + "Install torch>=1.8.1 to use profiler." 
+ ) + else: + if self._is_describe(): + output = [self.describe_handle()] + else: + data_preprocess = self.preprocess(data) + + if not self._is_explain(): + output = self.inference(data_preprocess) + output = self.postprocess(output) + else: + output = self.explain_handle(data_preprocess, data) + + stop_time = time.time() + metrics.add_time( + "HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms" + ) + return output diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java index 428845b366..ae46029991 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java @@ -6,6 +6,7 @@ import java.nio.file.FileAlreadyExistsException; import java.nio.file.Files; import java.util.List; +import java.util.Map; import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; import org.pytorch.serve.archive.DownloadArchiveException; @@ -187,11 +188,13 @@ public ModelConfig getModelConfig() { try { File configFile = new File(modelDir.getAbsolutePath(), manifest.getModel().getConfigFile()); - this.modelConfig = ArchiveUtils.readYamlFile(configFile, ModelConfig.class); + Map modelConfigMap = ArchiveUtils.readYamlFile(configFile); + this.modelConfig = ModelConfig.build(modelConfigMap); } catch (InvalidModelException | IOException e) { logger.error( "Failed to parse model config file {}", - manifest.getModel().getConfigFile(), e); + manifest.getModel().getConfigFile(), + e); } } return this.modelConfig; diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java index 4abdd2a960..87ec11fd02 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java @@ -1,8 +1,11 @@ package org.pytorch.serve.archive.model; -import java.util.ArrayList; import java.util.Arrays; +import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.stream.Collectors; +import java.util.stream.Stream; public class ModelConfig { private int minWorkers; @@ -10,13 +13,50 @@ public class ModelConfig { private int batchSize; private int maxBatchDelay; private int responseTimeout; - private String deviceType; - private CoreType coreType = CoreType.NONE; - private ArrayList deviceIds; + private DeviceType deviceType = DeviceType.NONE; + private List deviceIds; private int parallelLevel = 1; - private String parallelMode; private ParallelType parallelType = ParallelType.NONE; + public static ModelConfig build(Map yamlMap) { + ModelConfig modelConfig = new ModelConfig(); + yamlMap.forEach( + (k, v) -> { + switch (k) { + case "minWorkers": + modelConfig.setMinWorkers((int) v); + break; + case "maxWorkers": + modelConfig.setMaxWorkers((int) v); + break; + case "batchSize": + modelConfig.setBatchSize((int) v); + break; + case "maxBatchDelay": + modelConfig.setMaxBatchDelay((int) v); + break; + case "responseTimeout": + modelConfig.setResponseTimeout((int) v); + break; + case "deviceType": + modelConfig.setDeviceType((String) v); + break; + case "parallelLevel": + modelConfig.setParallelLevel((int) v); + break; + case "parallelType": + modelConfig.setParallelMode((String) v); + break; + case "deviceIds": + 
modelConfig.setDeviceIds(v); + break; + default: + break; + } + }); + return modelConfig; + } + public int getMinWorkers() { return minWorkers; } @@ -57,12 +97,16 @@ public void setResponseTimeout(int responseTimeout) { this.responseTimeout = responseTimeout; } - public ArrayList getDeviceIds() { + public List getDeviceIds() { return deviceIds; } - public void setDeviceIds(ArrayList deviceIds) { - this.deviceIds = deviceIds; + public void setDeviceIds(Object deviceIds) { + this.deviceIds = + Stream.of(deviceIds) + .map(Object::toString) + .map(Integer::parseInt) + .collect(Collectors.toList()); } public int getParallelLevel() { @@ -74,31 +118,21 @@ public void setParallelLevel(int parallelLevel) { } public void setParallelMode(String parallelMode) { - this.parallelMode = parallelMode; this.parallelType = ParallelType.get(parallelMode).get(); } - public String getParallelMode() { - return this.parallelMode; - } - public ParallelType getParallelType() { return this.parallelType; } public void setDeviceType(String deviceType) { - this.deviceType = deviceType; - this.coreType = CoreType.get(deviceType).get(); + this.deviceType = DeviceType.get(deviceType).get(); } - public String getDeviceType() { + public DeviceType getDeviceType() { return deviceType; } - public CoreType getCoreType() { - return coreType; - } - public enum ParallelType { NONE(""), PP("pp"), @@ -122,7 +156,7 @@ public static Optional get(String parallelType) { } } - public enum CoreType { + public enum DeviceType { NONE(""), CPU("cpu"), GPU("gpu"), @@ -130,17 +164,17 @@ public enum CoreType { private String type; - CoreType(String type) { + DeviceType(String type) { this.type = type; } - public String getCoreType() { + public String getDeviceType() { return type; } - public static Optional get(String coreType) { - return Arrays.stream(CoreType.values()) - .filter(t -> t.type.equals(coreType)) + public static Optional get(String deviceType) { + return Arrays.stream(DeviceType.values()) + .filter(t -> t.type.equals(deviceType)) .findFirst(); } } diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java index 5fd0817381..82c4681dd6 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java @@ -13,6 +13,7 @@ import java.nio.file.FileAlreadyExistsException; import java.nio.file.Files; import java.util.List; +import java.util.Map; import java.util.regex.Pattern; import org.apache.commons.io.FileUtils; import org.pytorch.serve.archive.DownloadArchiveException; @@ -55,6 +56,19 @@ public static T readYamlFile(File file, Class type) } } + public static Map readYamlFile(File file) + throws InvalidModelException, IOException { + Yaml yaml = new Yaml(); + try (Reader r = + new InputStreamReader( + Files.newInputStream(file.toPath()), StandardCharsets.UTF_8)) { + + return yaml.load(r); + } catch (YAMLException e) { + throw new InvalidModelException("Failed to parse model config yaml file.", e); + } + } + public static boolean validateURL(List allowedUrls, String url) throws InvalidArchiveURLException { boolean patternMatch = false; diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java index d313ec15a0..46333efe29 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java +++ 
b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java @@ -2,8 +2,8 @@ import com.google.gson.JsonObject; import java.io.File; -import java.util.ArrayList; import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; @@ -42,11 +42,11 @@ public class Model { private int maxBatchDelay; private int parallelLevel = 1; private ModelConfig.ParallelType parallelType = ModelConfig.ParallelType.NONE; - private ModelConfig.CoreType coreType = + private ModelConfig.DeviceType deviceType = ConfigManager.getInstance().getNumberOfGpu() > 0 - ? ModelConfig.CoreType.GPU - : ModelConfig.CoreType.CPU; - private ArrayList coreIds; + ? ModelConfig.DeviceType.GPU + : ModelConfig.DeviceType.CPU; + private List coreIds; private int numCores; private ReentrantLock lock; private int responseTimeout; @@ -69,12 +69,12 @@ public Model(ModelArchive modelArchive, int queueSize) { parallelLevel = modelArchive.getModelConfig().getParallelLevel(); parallelType = modelArchive.getModelConfig().getParallelType(); } - if (modelArchive.getModelConfig().getCoreType() != ModelConfig.CoreType.NONE) { - coreType = - (modelArchive.getModelConfig().getCoreType() == ModelConfig.CoreType.GPU + if (modelArchive.getModelConfig().getDeviceType() != ModelConfig.DeviceType.NONE) { + deviceType = + (modelArchive.getModelConfig().getDeviceType() == ModelConfig.DeviceType.GPU && ConfigManager.getInstance().getNumberOfGpu() > 0) - ? ModelConfig.CoreType.GPU - : coreType; + ? ModelConfig.DeviceType.GPU + : deviceType; } coreIds = modelArchive.getModelConfig().getDeviceIds(); } else { @@ -291,11 +291,11 @@ public void setResponseTimeout(int responseTimeout) { this.responseTimeout = responseTimeout; } - public ArrayList getCoreIds() { + public List getCoreIds() { return this.coreIds; } - public void setCoreIdsIds(ArrayList coreIds) { + public void setCoreIdsIds(List coreIds) { Collections.copy(this.coreIds, coreIds); } @@ -311,8 +311,8 @@ public ModelConfig.ParallelType getParallelType() { return this.parallelType; } - public ModelConfig.CoreType getCoreType() { - return this.coreType; + public ModelConfig.DeviceType getDeviceType() { + return this.deviceType; } public int getNumCores() { diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java index 63e26ef20b..e0f0f7b99b 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java @@ -426,7 +426,7 @@ public CompletableFuture updateModel( throw new ModelVersionNotFoundException( "Model version: " + versionId + " does not exist for model: " + modelName); } - if (model.getParallelLevel() > 1 && model.getCoreType() == ModelConfig.CoreType.GPU) { + if (model.getParallelLevel() > 1 && model.getDeviceType() == ModelConfig.DeviceType.GPU) { /** * Current capacity check for LMI is based on single node. TODO: multiple nodes check * will be based on --proc-per-node + numCores. 
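To make the flow concrete: the frontend parses the model-config YAML into a flat key/value map for `ModelConfig.build()` above, and the `ts/context.py` and `ts/service.py` changes below hand that same parsed map to the Python backend as `ctx.model_yaml_config`. The sketch below is illustrative only; the key names follow the switch cases above and the example config later in this series, while the concrete values are made up.

```python
import yaml

# Illustrative model-config.yaml content (values are examples, not requirements).
yaml_text = """
minWorkers: 1
maxWorkers: 1
batchSize: 1
maxBatchDelay: 100
responseTimeout: 120
deviceType: gpu
deviceIds: [0, 1, 2, 3]
parallelLevel: 4
parallelType: pp
pippy:
  rpc_timeout: 1800
  pp_group_size: 4
"""

# This is what the frontend parses and what a handler later sees as ctx.model_yaml_config.
config = yaml.safe_load(yaml_text)
assert config["parallelLevel"] == 4
assert config["deviceIds"] == [0, 1, 2, 3]
rpc_timeout = int(config.get("pippy", {}).get("rpc_timeout", 1800))
```

Keys that `ModelConfig.build()` does not recognize fall through its `default` branch on the frontend, but they remain in the parsed map, which is what lets handler-specific sections such as `pippy` ride along in the same file.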
diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java index 088d84cef5..d48d85bc5b 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java @@ -14,7 +14,6 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; - import org.pytorch.serve.archive.model.ModelConfig; import org.pytorch.serve.snapshot.SnapshotManager; import org.pytorch.serve.util.ConfigManager; @@ -198,7 +197,7 @@ private void addThreads( for (int i = 0; i < count; ++i) { int gpuId = -1; - if (maxGpu > 0 && model.getCoreType() == ModelConfig.CoreType.GPU) { + if (maxGpu > 0 && model.getDeviceType() == ModelConfig.DeviceType.GPU) { if (model.getParallelLevel() > 1) { gpuId = parallelGpuIdx; parallelGpuIdx += model.getParallelLevel(); diff --git a/ts/context.py b/ts/context.py index 97e79ae84b..16b100f6f5 100644 --- a/ts/context.py +++ b/ts/context.py @@ -21,6 +21,7 @@ def __init__( mms_version, limit_max_image_pixels=True, metrics=None, + model_yaml_config=None ): self.model_name = model_name self.manifest = manifest @@ -37,6 +38,7 @@ def __init__( self._metrics = None self._limit_max_image_pixels = True self.metrics = metrics + self.model_yaml_config = model_yaml_config @property def system_properties(self): diff --git a/ts/service.py b/ts/service.py index c20fd79bed..dacbb4be47 100644 --- a/ts/service.py +++ b/ts/service.py @@ -2,6 +2,7 @@ CustomService class definitions """ import logging +import os import time from builtins import str @@ -9,6 +10,7 @@ from ts.context import Context, RequestProcessor from ts.protocol.otf_message_handler import create_predict_response from ts.utils.util import PredictionException +from ts.utils.util import get_yaml_config PREDICTION_METRIC = "PredictionTime" logger = logging.getLogger(__name__) @@ -30,6 +32,14 @@ def __init__( limit_max_image_pixels=True, metrics_cache=None, ): + model_yaml_config = None + model_yaml_config_file = None + if manifest is not None and "configFile" in manifest["model"]: + model_yaml_config_file = manifest["model"]["configFile"] + if model_yaml_config_file is not None: + model_yaml_config_file_path = os.path.join(model_dir, model_yaml_config_file) + model_yaml_config = get_yaml_config(model_yaml_config_file_path) + self._context = Context( model_name, model_dir, @@ -39,6 +49,7 @@ def __init__( ts.__version__, limit_max_image_pixels, metrics_cache, + model_yaml_config ) self._entry_point = entry_point diff --git a/ts/utils/util.py b/ts/utils/util.py index 629f274008..0fd346d5b2 100644 --- a/ts/utils/util.py +++ b/ts/utils/util.py @@ -8,7 +8,7 @@ import logging import os import re - +import yaml class PT2Backend(str, enum.Enum): EAGER = "eager" @@ -135,6 +135,11 @@ def map_class_to_label(probs, mapping=None, lbl_classes=None): return results +def get_yaml_config(yaml_file_path): + config = None + with open(yaml_file_path, 'r'): + config = yaml.safe_load(yaml_file_path) + return config class PredictionException(Exception): def __init__(self, message, error_code=500): From fc115bc9a75694a3946c1cd1d6b198dae46b31fc Mon Sep 17 00:00:00 2001 From: lxning Date: Mon, 13 Mar 2023 10:08:47 -0700 Subject: [PATCH 24/47] add tp handler init --- .../pippy_pptp_handler.py | 191 ++++++++++++------ 1 file changed, 128 insertions(+), 63 deletions(-) diff --git 
a/examples/Huggingface_Largemodels/pippy_pptp_handler.py b/examples/Huggingface_Largemodels/pippy_pptp_handler.py index ee5bb05b83..af10eb67df 100644 --- a/examples/Huggingface_Largemodels/pippy_pptp_handler.py +++ b/examples/Huggingface_Largemodels/pippy_pptp_handler.py @@ -7,32 +7,27 @@ import os import time +import torch import pippy import pippy.fx -from pippy import run_pippy -from pippy.IR import pipe_split +from torch.distributed._tensor import ( + DeviceMesh, +) +from torch.distributed.tensor.parallel import ( + PairwiseParallel, + parallelize_module, +) + + from pippy.IR import MultiUseParameterConfig, Pipe from pippy.PipelineDriver import PipelineDriverFillDrain, PipelineDriver1F1B, PipelineDriverInterleaved1F1B, \ PipelineDriverBase -from pippy.hf import PiPPyHFTracer -from pippy.microbatch import TensorChunkSpec from pippy import split_on_size_threshold, split_into_equal_size -from transformers import AutoModelForSeq2SeqLM from transformers import OPTModel, BloomModel -from PIL import Image -import requests -from transformers import AutoFeatureExtractor, RegNetModel -from transformers import OPTForCausalLM import torch.distributed.rpc as rpc -import torch import transformers from transformers import BloomForCausalLM, BloomTokenizerFast - - -from pippy import run_pippy -from pippy.IR import pipe_split - from ts.torch_handler.base_handler import BaseHandler logger = logging.getLogger(__name__) @@ -60,33 +55,11 @@ def __init__(self): super(TransformersSeqClassifierHandler, self).__init__() self.initialized = False self.local_rank = int(os.environ["LOCAL_RANK"]) + self.rank = int(os.environ["RANK"]) self.world_size = int(os.environ["WORLD_SIZE"]) - - options = rpc.TensorPipeRpcBackendOptions( - num_worker_threads=512, - rpc_timeout=1800 - # transports=None, - ) - - - # if args.cuda: - n_devs = torch.cuda.device_count() - print(f"n_devs={n_devs}") - dev_id = self.local_rank % n_devs - for i in range (self.world_size): - print(f"worker{i}, {dev_id}: {i % n_devs}") - options.set_device_map(f"worker{i}", {dev_id: i % n_devs}) - - self.device = f"cuda:{dev_id}" - print( - f"rank = {self.local_rank} pid/device = " - f"{os.getpid()}/{self.device}" - ) - - rpc.init_rpc(f"worker{self.local_rank}", - rank=self.local_rank, - world_size=self.world_size, - rpc_backend_options=options) + self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + self.pp_rank = 0 + self.pp_ranks = None def initialize(self, ctx): """In this initialize function, the BERT model is loaded and @@ -96,12 +69,77 @@ def initialize(self, ctx): ctx (context): It is a JSON Object containing information pertaining to the model artefacts parameters. 
""" - # parser = argparse.ArgumentParser() - # args = parser.parse_args() - # args.world_size = 4 - # args.gspmd = 1 - #if self.local_rank != 0: - # return + pp_group_size = self.world_size + num_worker_threads = 512 + rpc_timeout = 1800 + if ctx.model_yaml_config is not None: + if ctx.system_properties.get("gpu_id") != -1 \ + and ctx.model_yaml_config["deviceIds"] is not None: + device_ids = ','.join(str(e) for e in ctx.model_yaml_config["deviceIds"][int(ctx.system_properties.get("gpu_id")):int(ctx.system_properties.get("gpu_id"))+self.world_size+1]) + os.environ["CUDA_VISIBLE_DEVICE"] = device_ids + + if ctx.model_yaml_config[pippy] is not None: + if ctx.model_yaml_config["pippy"]["pp_group_size"] is not None \ + and self.world_size % int(ctx.model_yaml_config["pippy"]["pp_group_size"]) == 0: + pp_group_size = int(ctx.model_yaml_config["pippy"]["pp_group_size"]) + + if ctx.model_yaml_config["pippy"]["num_worker_threads"] is not None: + num_worker_threads = int(ctx.model_yaml_config["pippy"]["num_worker_threads"]) + + if ctx.model_yaml_config["pippy"]["rpc_timeout"] is not None: + rpc_timeout = int(ctx.model_yaml_config["pippy"]["rpc_timeout"]) + + if ctx.system_properties.get("gpu_id") != -1 and os.environ["CUDA_VISIBLE_DEVICE"] is None: + os.environ["CUDA_VISIBLE_DEVICE"] = ','.join(str(e) for e in range(self.local_rank)) + + options = rpc.TensorPipeRpcBackendOptions( + num_worker_threads, + rpc_timeout + ) + device_type = "cpu" + if int(ctx.system_properties.get("gpu_id")) != -1: + device_type = "cuda" + n_devs = torch.cuda.device_count() + dev_id = self.local_rank % n_devs + for i in range(self.world_size): + logging.info(f"worker{i}, {dev_id}: {i % n_devs}") + options.set_device_map(f"worker{i}", {dev_id: i % n_devs}) + + self.device = f"cuda:{dev_id}" + logging.info( + f"rank = {self.local_rank} pid/device = " + f"{os.getpid()}/{self.device}" + ) + else: + self.device = "cpu" + rpc.init_rpc(f"worker{self.rank}", + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options) + + tp_group_size = self.world_size // pp_group_size + dp_group_size = self.world_size // pp_group_size + + logging.info( + f"[PiPPy] World size: {self.world_size}, " + f"DP group size: {dp_group_size}, " + f"PP group size: {pp_group_size}" + ) + pp_ranks_per_dp_group = [ + [i + rank for i in range(pp_group_size)] + for rank in range(dp_group_size) + ] + self.pp_ranks = pp_ranks_per_dp_group[self.rank % dp_group_size] + self.pp_rank = self.rank // tp_group_size + logging.info(f"Global rank {self.rank}, pipeline: {self.pp_ranks}, my rank in pipe: {self.pp_rank}") + + d_hid = 256 + batch_size_per_chunk = 8 + chunks = pp_group_size + #inp_size = [chunks * batch_size_per_chunk, d_hid] + # Ensure all tp ranks have same input. 
+ #torch.manual_seed(0) + #inp = torch.rand(*inp_size, device=device_type) self.manifest = ctx.manifest properties = ctx.system_properties @@ -125,7 +163,7 @@ def initialize(self, ctx): else: logger.warning("Missing the setup_config.json file.") - torch.manual_seed(42) + replicate = 0 schedule = list(schedules.keys())[0] MULTI_USE_PARAM_CONFIG = MultiUseParameterConfig.REPLICATE if replicate else MultiUseParameterConfig.TRANSMIT @@ -150,28 +188,53 @@ def initialize(self, ctx): split_policy = split_into_equal_size(self.world_size) - pp_ranks = [0,1,2,3] - all_worker_ranks = list(range(self.world_size)) - chunks = 1 - bs = 1 * chunks - seq_length = 16 input_names = ['input_ids'] sig = inspect.signature(model.forward) concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} + torch.manual_seed(0) + ec_tp = model(d_hid) + ec_tp.to(self.device) + start_idx = 0 + device_mesh = DeviceMesh( + device_type, + list(range(start_idx, start_idx + tp_group_size)), + ) + logging.info(f"Rank {self.rank} calling parallelize_module with {device_mesh}") + parallelize_module(ec_tp, device_mesh, PairwiseParallel()) + logging.info(f"Rank {self.rank} sharding complete") + print('Instantiating model Pipeline') model_init_start = time.time() - pipe_driver, stage_mode = pippy.all_compile( + # Get: + # - pipeline driver (for pipeline head rank) + # - stage submodule (for all ranks) + pipe_driver, submod = pippy.all_compile( model, - num_ranks=self.world_size, - num_chunks=chunks, - schedule="FillDrain", - split_policy=split_policy, - tracer=PiPPyHFTracer(), - concrete_args=concrete_args, + pp_group_size, + chunks, + ranks=self.pp_ranks, ) + + # Create TP device mesh + my_device_mesh = None + for stage in range(pp_group_size): + start_rank = stage * tp_group_size + tp_ranks = list(range(start_rank, start_rank + tp_group_size)) + tp_device_mesh = DeviceMesh( + device_type, + tp_ranks, + ) + if stage == self.pp_rank: + my_device_mesh = tp_device_mesh + + # Tensor parallelize submodules + print(f"Rank {self.rank} calling parallelize_module with {my_device_mesh}") + parallelize_module(submod, my_device_mesh, PairwiseParallel()) + + # model_pipe = Pipe.from_tracing(self.model, MULTI_USE_PARAM_CONFIG, tracer=PiPPyHFTracer(), concrete_args=concrete_args, # output_loss_value_spec=None, split_policy=split_policy # ) @@ -256,9 +319,11 @@ def inference(self, input_batch): # inferences.append( # self.tokenizer.decode(outputs[i], skip_special_tokens=True) # ) - if self.local_rank==0: - output = self.model(**model_input_dict) - # rpc.shutdown() + #if self.pp_rank == 0: + # print(f"Rank {self.rank} Instantiated pipeline with ranks {self.pp_ranks}") + output = self.model(**model_input_dict) + + print("************** here is the output",type(output)) logger.info("Generated text: '%s'", inferences) inferences.append(output) From 6d315a070329fa8a935225f76610caa063d971d9 Mon Sep 17 00:00:00 2001 From: lxning Date: Tue, 14 Mar 2023 17:35:20 -0700 Subject: [PATCH 25/47] fix model config parser --- ts/service.py | 11 +++++------ ts/utils/util.py | 4 ++-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/ts/service.py b/ts/service.py index dacbb4be47..724e7b2bb9 100644 --- a/ts/service.py +++ b/ts/service.py @@ -33,12 +33,11 @@ def __init__( metrics_cache=None, ): model_yaml_config = None - model_yaml_config_file = None - if manifest is not None and "configFile" in manifest["model"]: - model_yaml_config_file = manifest["model"]["configFile"] - if model_yaml_config_file is not None: - 
model_yaml_config_file_path = os.path.join(model_dir, model_yaml_config_file) - model_yaml_config = get_yaml_config(model_yaml_config_file_path) + if manifest is not None and "model" in manifest: + model = manifest["model"] + if "configFile" in model: + model_yaml_config_file = model["configFile"] + model_yaml_config = get_yaml_config(os.path.join(model_dir, model_yaml_config_file)) self._context = Context( model_name, diff --git a/ts/utils/util.py b/ts/utils/util.py index 0fd346d5b2..0a5f914a95 100644 --- a/ts/utils/util.py +++ b/ts/utils/util.py @@ -137,8 +137,8 @@ def map_class_to_label(probs, mapping=None, lbl_classes=None): def get_yaml_config(yaml_file_path): config = None - with open(yaml_file_path, 'r'): - config = yaml.safe_load(yaml_file_path) + with open(yaml_file_path, 'r') as file: + config = yaml.safe_load(file) return config class PredictionException(Exception): From 71bcaccdb7712b0c70a0da7ff3abec489594505f Mon Sep 17 00:00:00 2001 From: lxning Date: Wed, 15 Mar 2023 09:12:56 -0700 Subject: [PATCH 26/47] update pippy_pptp_handler --- .../Huggingface_Largemodels/model-config.yaml | 16 +- .../Huggingface_Largemodels/pippy_handler.py | 309 ------------------ .../pippy_pptp_handler.py | 15 +- 3 files changed, 18 insertions(+), 322 deletions(-) delete mode 100644 examples/Huggingface_Largemodels/pippy_handler.py diff --git a/examples/Huggingface_Largemodels/model-config.yaml b/examples/Huggingface_Largemodels/model-config.yaml index e3dcf74c27..29c1314de0 100644 --- a/examples/Huggingface_Largemodels/model-config.yaml +++ b/examples/Huggingface_Largemodels/model-config.yaml @@ -1,15 +1,21 @@ +# TS Frontend parameters minWorkers: 1 maxWorkers: 1 +batchSize: 1 maxBatchDelay: 100 responseTimeout: 120 -coreType: cpu # cpu, gpu, neuron -coreIds: [0,1,2.3] # core index for gpu, neuron -parallelLevel: 4 -parallelType: pp # pp: pipeline parallel; tp: tensor parallel; tp+pp +deviceType: cpu # cpu, gpu, neuron +deviceIds: [0,1,2.3] # device index for gpu, neuron +parallelLevel: 4 # rpc world size +parallelType: pp # pp: pipeline parallel; pptp: tensor+pipeline parallel torchrun: - logLevel: INFO + max_restarts: 3 + +# TS backend parameters pippy: rpc_timeout: 1800 + pp_group_size: 4 # pipeline parallel size, tp_group_size = world size / pp_group_size + diff --git a/examples/Huggingface_Largemodels/pippy_handler.py b/examples/Huggingface_Largemodels/pippy_handler.py deleted file mode 100644 index b726e4a8a8..0000000000 --- a/examples/Huggingface_Largemodels/pippy_handler.py +++ /dev/null @@ -1,309 +0,0 @@ -import json -import logging -import os -import zipfile -from abc import ABC - -import argparse -import inspect -import logging -import os -import time - -import pippy.fx -#from pippy import run_pippy -from pippy.IR import MultiUseParameterConfig, Pipe -from pippy.PipelineDriver import PipelineDriverFillDrain, PipelineDriver1F1B, PipelineDriverInterleaved1F1B, \ - PipelineDriverBase -from pippy.hf import PiPPyHFTracer -from pippy.microbatch import TensorChunkSpec -from pippy import split_on_size_threshold, split_into_equal_size -from transformers import AutoModelForSeq2SeqLM -from transformers import OPTModel, BloomModel -from PIL import Image -import requests -from transformers import AutoFeatureExtractor, RegNetModel -from transformers import OPTForCausalLM -import torch.distributed.rpc as rpc - -import torch -import transformers -from transformers import BloomForCausalLM, BloomTokenizerFast - -from ts.torch_handler.base_handler import BaseHandler - -logger = logging.getLogger(__name__) 
-logger.info("Transformers version %s", transformers.__version__) - -PIPPY_VERBOSITY = os.environ.get("PIPPY_VERBOSITY", "DEBUG") -TORCH_DTYPES = { - "float16": torch.float16, - "float32": torch.float32, - "float64": torch.float64, -} - -schedules = { - 'FillDrain': PipelineDriverFillDrain, - '1F1B': PipelineDriver1F1B, - 'Interleaved1F1B': PipelineDriverInterleaved1F1B, -} - -class TransformersSeqClassifierHandler(BaseHandler, ABC): - """ - Transformers handler class for sequence, token classification and question answering. - """ - - def __init__(self): - super(TransformersSeqClassifierHandler, self).__init__() - self.initialized = False - self.local_rank = int(os.environ["LOCAL_RANK"]) - self.world_size = int(os.environ["WORLD_SIZE"]) - - options = rpc.TensorPipeRpcBackendOptions( - num_worker_threads=512, - rpc_timeout=1800 - # transports=None, - ) - - - # if args.cuda: - n_devs = torch.cuda.device_count() - print(f"n_devs={n_devs}") - dev_id = self.local_rank % n_devs - for i in range (self.world_size): - print(f"worker{i}, {dev_id}: {i % n_devs}") - options.set_device_map(f"worker{i}", {dev_id: i % n_devs}) - - self.device = f"cuda:{dev_id}" - print( - f"rank = {self.local_rank} pid/device = " - f"{os.getpid()}/{self.device}" - ) - - rpc.init_rpc(f"worker{self.local_rank}", - rank=self.local_rank, - world_size=self.world_size, - rpc_backend_options=options) - - def initialize(self, ctx): - """In this initialize function, the BERT model is loaded and - the Layer Integrated Gradients Algorithm for Captum Explanations - is initialized here. - Args: - ctx (context): It is a JSON Object containing information - pertaining to the model artefacts parameters. - """ - # parser = argparse.ArgumentParser() - # args = parser.parse_args() - # args.world_size = 4 - # args.gspmd = 1 - #if self.local_rank != 0: - # return - - self.manifest = ctx.manifest - properties = ctx.system_properties - model_dir = properties.get("model_dir") - - # self.device = torch.device( - # "cuda:" + str(properties.get("gpu_id")) - # if torch.cuda.is_available() and properties.get("gpu_id") is not None - # else "cpu" - # ) - # Loading the model and tokenizer from checkpoint and config files based on the user's choice of mode - # further setup config can be added. - with zipfile.ZipFile(model_dir + "/model.zip", "r") as zip_ref: - zip_ref.extractall(model_dir + "/model") - - # read configs for the mode, model_name, etc. 
from setup_config.json - setup_config_path = os.path.join(model_dir, "setup_config.json") - if os.path.isfile(setup_config_path): - with open(setup_config_path) as setup_config_file: - self.setup_config = json.load(setup_config_file) - else: - logger.warning("Missing the setup_config.json file.") - - torch.manual_seed(42) - replicate = 0 - schedule = list(schedules.keys())[0] - MULTI_USE_PARAM_CONFIG = MultiUseParameterConfig.REPLICATE if replicate else MultiUseParameterConfig.TRANSMIT - print(f'REPLICATE config: {replicate} -> {MULTI_USE_PARAM_CONFIG}') - print("Using schedule:", schedule) - - model = BloomModel.from_pretrained( - model_dir + "/model", use_cache=False) - - self.tokenizer = BloomTokenizerFast.from_pretrained( - model_dir + "/model", return_tensors="pt" - ) - - logger.info("********************* model loaded *************************", model_dir) - - # model = BloomModel.from_pretrained("bigscience/bloom-3b", use_cache=False) - - model_config = model.config - - model_config.use_cache = False # don't output `past_key_values` - model.eval() - - - split_policy = split_into_equal_size(self.world_size) - pp_ranks = [0,1,2,3] - all_worker_ranks = list(range(self.world_size)) - chunks = 1 - bs = 1 * chunks - seq_length = 16 - - - input_names = ['input_ids'] - sig = inspect.signature(model.forward) - concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} - - print('Instantiating model Pipeline') - model_init_start = time.time() - pipe_driver, stage_mode = pippy.all_compile( - model, - num_ranks=self.world_size, - num_chunks=chunks, - schedule="FillDrain", - split_policy=split_policy, - tracer=PiPPyHFTracer(), - concrete_args=concrete_args, - ) - # model_pipe = Pipe.from_tracing(self.model, MULTI_USE_PARAM_CONFIG, tracer=PiPPyHFTracer(), concrete_args=concrete_args, - # output_loss_value_spec=None, split_policy=split_policy - # ) - - # model_pipe.defer_stage_init(self.device + self.local_rank) - - # pippy.utils.pp_group_barrier() - - # split_gm_children = list(model_pipe.split_gm.children()) - - # pipe_driver: PipelineDriverBase = schedules["FillDrain"](model_pipe, chunks, - # world_size=self.world_size, - # all_ranks=all_worker_ranks, - # ) - - self.model = pipe_driver - logger.info("Transformer model from path %s loaded successfully", model_dir) - - self.initialized = True - - - def preprocess(self, requests): - """Basic text preprocessing, based on the user's chocie of application mode. - Args: - requests (str): The Input data in the form of text is passed on to the preprocess - function. - Returns: - list : The preprocess function returns a list of Tensor for the size of the word tokens. - """ - input_ids_batch = None - attention_mask_batch = None - for idx, data in enumerate(requests): - input_text = data.get("data") - if input_text is None: - input_text = data.get("body") - if isinstance(input_text, (bytes, bytearray)): - input_text = input_text.decode("utf-8") - - max_length = self.setup_config["max_length"] - logger.info("Received text: '%s'", input_text) - - inputs = self.tokenizer.encode_plus( - input_text, - max_length=int(max_length), - pad_to_max_length=True, - add_special_tokens=True, - return_tensors="pt", - ) - - input_ids = inputs["input_ids"].to(self.device) - attention_mask = inputs["attention_mask"].to(self.device) - # making a batch out of the recieved requests - # attention masks are passed for cases where input tokens are padded. 
- if input_ids.shape is not None: - if input_ids_batch is None: - input_ids_batch = input_ids - attention_mask_batch = attention_mask - else: - input_ids_batch = torch.cat((input_ids_batch, input_ids), 0) - attention_mask_batch = torch.cat( - (attention_mask_batch, attention_mask), 0 - ) - return (input_ids_batch, attention_mask_batch) - - def inference(self, input_batch): - """Predict the class (or classes) of the received text using the - serialized transformers checkpoint. - Args: - input_batch (list): List of Text Tensors from the pre-process function is passed here - Returns: - list : It returns a list of the predicted value for the input text - """ - (input_ids_batch, _) = input_batch - inferences = [] - input_ids_batch = input_ids_batch.to(self.device) - model_input_dict = {} - model_input_dict["input_ids"]=input_ids_batch - # outputs = self.model.generate( - # input_ids_batch, do_sample=True, max_length=50, top_p=0.95, top_k=60 - # ) - # for i, _ in enumerate(outputs): - # inferences.append( - # self.tokenizer.decode(outputs[i], skip_special_tokens=True) - # ) - if self.local_rank==0: - output = self.model(**model_input_dict) - # rpc.shutdown() - print("************** here is the output",type(output)) - logger.info("Generated text: '%s'", inferences) - inferences.append(output) - print("Generated text", inferences) - return inferences - - def postprocess(self, inference_output): - """Post Process Function converts the predicted response into Torchserve readable format. - Args: - inference_output (list): It contains the predicted response of the input text. - Returns: - (list): Returns a list of the Predictions and Explanations. - """ - return inference_output - - def handle(self, data, context): - if self.local_rank != 0: - pass - start_time = time.time() - - self.context = context - metrics = self.context.metrics - - #run_pippy(self.initialize, context) - - is_profiler_enabled = os.environ.get("ENABLE_TORCH_PROFILER", None) - if is_profiler_enabled: - if PROFILER_AVAILABLE: - output, _ = self._infer_with_profiler(data=data) - else: - raise RuntimeError( - "Profiler is enabled but current version of torch does not support." - "Install torch>=1.8.1 to use profiler." 
- ) - else: - if self._is_describe(): - output = [self.describe_handle()] - else: - data_preprocess = self.preprocess(data) - - if not self._is_explain(): - output = self.inference(data_preprocess) - output = self.postprocess(output) - else: - output = self.explain_handle(data_preprocess, data) - - stop_time = time.time() - metrics.add_time( - "HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms" - ) - return output diff --git a/examples/Huggingface_Largemodels/pippy_pptp_handler.py b/examples/Huggingface_Largemodels/pippy_pptp_handler.py index af10eb67df..7776aee429 100644 --- a/examples/Huggingface_Largemodels/pippy_pptp_handler.py +++ b/examples/Huggingface_Largemodels/pippy_pptp_handler.py @@ -18,7 +18,6 @@ parallelize_module, ) - from pippy.IR import MultiUseParameterConfig, Pipe from pippy.PipelineDriver import PipelineDriverFillDrain, PipelineDriver1F1B, PipelineDriverInterleaved1F1B, \ PipelineDriverBase @@ -74,23 +73,23 @@ def initialize(self, ctx): rpc_timeout = 1800 if ctx.model_yaml_config is not None: if ctx.system_properties.get("gpu_id") != -1 \ - and ctx.model_yaml_config["deviceIds"] is not None: - device_ids = ','.join(str(e) for e in ctx.model_yaml_config["deviceIds"][int(ctx.system_properties.get("gpu_id")):int(ctx.system_properties.get("gpu_id"))+self.world_size+1]) + and "deviceIds" in ctx.model_yaml_config: + device_ids = ','.join(str(e) for e in ctx.model_yaml_config["deviceIds"][int(ctx.system_properties.get("gpu_id")):int(ctx.system_properties.get("gpu_id"))+self.loca_world_size+1]) os.environ["CUDA_VISIBLE_DEVICE"] = device_ids - if ctx.model_yaml_config[pippy] is not None: - if ctx.model_yaml_config["pippy"]["pp_group_size"] is not None \ + if "pippy" in ctx.model_yaml_config: + if "pp_group_size" in ctx.model_yaml_config["pippy"] \ and self.world_size % int(ctx.model_yaml_config["pippy"]["pp_group_size"]) == 0: pp_group_size = int(ctx.model_yaml_config["pippy"]["pp_group_size"]) - if ctx.model_yaml_config["pippy"]["num_worker_threads"] is not None: + if "num_worker_threads" in ctx.model_yaml_config["pippy"]: num_worker_threads = int(ctx.model_yaml_config["pippy"]["num_worker_threads"]) - if ctx.model_yaml_config["pippy"]["rpc_timeout"] is not None: + if "rpc_timeout" in ctx.model_yaml_config["pippy"]: rpc_timeout = int(ctx.model_yaml_config["pippy"]["rpc_timeout"]) if ctx.system_properties.get("gpu_id") != -1 and os.environ["CUDA_VISIBLE_DEVICE"] is None: - os.environ["CUDA_VISIBLE_DEVICE"] = ','.join(str(e) for e in range(self.local_rank)) + os.environ["CUDA_VISIBLE_DEVICE"] = ','.join(str(e) for e in range(int(ctx.system_properties.get("gpu_id")), int(ctx.system_properties.get("gpu_id")) + self.loca_world_size+1)) options = rpc.TensorPipeRpcBackendOptions( num_worker_threads, From ea58747f611bc1ebb8d5baacadbbdc5f37556206 Mon Sep 17 00:00:00 2001 From: lxning Date: Wed, 15 Mar 2023 19:40:21 -0700 Subject: [PATCH 27/47] get gpuid by rotating model's deviceIds --- .../java/org/pytorch/serve/wlm/Model.java | 23 +++++++++++------- .../pytorch/serve/wlm/WorkLoadManager.java | 24 +++++++++---------- ts/service.py | 6 ++++- ts/utils/util.py | 2 +- 4 files changed, 32 insertions(+), 23 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java index 46333efe29..cacfa09c2b 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java @@ -46,11 +46,12 @@ public class Model { 
ConfigManager.getInstance().getNumberOfGpu() > 0 ? ModelConfig.DeviceType.GPU : ModelConfig.DeviceType.CPU; - private List coreIds; + private List deviceIds; private int numCores; private ReentrantLock lock; private int responseTimeout; private ModelVersionName modelVersionName; + private AtomicInteger gpuCounter = new AtomicInteger(0); private boolean isWorkflowModel; @@ -76,7 +77,7 @@ public Model(ModelArchive modelArchive, int queueSize) { ? ModelConfig.DeviceType.GPU : deviceType; } - coreIds = modelArchive.getModelConfig().getDeviceIds(); + deviceIds = modelArchive.getModelConfig().getDeviceIds(); } else { batchSize = 1; maxBatchDelay = 100; @@ -84,8 +85,10 @@ public Model(ModelArchive modelArchive, int queueSize) { if (ConfigManager.getInstance().getNumberOfGpu() > 0) { numCores = - (coreIds != null && coreIds.size() > 0) - ? coreIds.size() + (deviceType == ModelConfig.DeviceType.GPU + && deviceIds != null + && deviceIds.size() > 0) + ? deviceIds.size() : ConfigManager.getInstance().getNumberOfGpu(); } @@ -291,12 +294,12 @@ public void setResponseTimeout(int responseTimeout) { this.responseTimeout = responseTimeout; } - public List getCoreIds() { - return this.coreIds; + public List getDeviceIds() { + return this.deviceIds; } - public void setCoreIdsIds(List coreIds) { - Collections.copy(this.coreIds, coreIds); + public void setDeviceIds(List deviceIds) { + Collections.copy(this.deviceIds, deviceIds); } public int getParallelLevel() { @@ -318,4 +321,8 @@ public ModelConfig.DeviceType getDeviceType() { public int getNumCores() { return this.numCores; } + + public AtomicInteger getGpuCounter() { + return gpuCounter; + } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java index d48d85bc5b..c8915143b7 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java @@ -14,7 +14,6 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; -import org.pytorch.serve.archive.model.ModelConfig; import org.pytorch.serve.snapshot.SnapshotManager; import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.OSUtils; @@ -193,22 +192,21 @@ private void addThreads( List threads, Model model, int count, CompletableFuture future) { WorkerStateListener listener = new WorkerStateListener(future, count); int maxGpu = model.getNumCores(); - int parallelGpuIdx = 0; for (int i = 0; i < count; ++i) { int gpuId = -1; - if (maxGpu > 0 && model.getDeviceType() == ModelConfig.DeviceType.GPU) { - if (model.getParallelLevel() > 1) { - gpuId = parallelGpuIdx; - parallelGpuIdx += model.getParallelLevel(); + if (maxGpu > 0) { + if (model.getDeviceIds() != null && model.getDeviceIds().size() > 0) { + gpuId = + model.getGpuCounter() + .getAndAccumulate( + maxGpu, + (prev, maxGpuId) -> + (prev + model.getParallelLevel()) % maxGpuId); } else { - if (model.getCoreIds() != null) { - gpuId = model.getCoreIds().get(parallelGpuIdx++ % maxGpu); - } else { - gpuId = - gpuCounter.accumulateAndGet( - maxGpu, (prev, maxGpuId) -> ++prev % maxGpuId); - } + gpuId = + gpuCounter.accumulateAndGet( + maxGpu, (prev, maxGpuId) -> ++prev % maxGpuId); } } diff --git a/ts/service.py b/ts/service.py index 724e7b2bb9..25bbd39b8f 100644 --- a/ts/service.py +++ b/ts/service.py @@ -32,13 +32,17 @@ def __init__( limit_max_image_pixels=True, 
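The `deviceIds` branch of the new GPU assignment above relies on `AtomicInteger.getAndAccumulate`: each worker takes the counter's previous value as its base GPU id, and the counter then advances by `parallelLevel` modulo the number of usable devices. A small sketch with illustrative numbers:

```python
# Equivalent of getAndAccumulate(maxGpu, (prev, maxGpuId) -> (prev + parallelLevel) % maxGpuId)
max_gpu, parallel_level = 8, 4   # illustrative: 8 usable devices, 4-way parallel workers
counter = 0
for worker in range(4):
    gpu_id = counter                                 # previous value -> this worker's base GPU
    counter = (counter + parallel_level) % max_gpu   # next worker starts parallelLevel devices later
    print(f"worker {worker} -> base gpu {gpu_id}")   # 0, 4, 0, 4
```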
metrics_cache=None, ): - model_yaml_config = None + model_yaml_config = dict() if manifest is not None and "model" in manifest: model = manifest["model"] if "configFile" in model: model_yaml_config_file = model["configFile"] model_yaml_config = get_yaml_config(os.path.join(model_dir, model_yaml_config_file)) + if "deviceIds" in model_yaml_config and "parallelLevel" in model_yaml_config: + if int(model_yaml_config["parallelLevel"]) == 1: + gpu = model_yaml_config["deviceIds"][gpu] + self._context = Context( model_name, model_dir, diff --git a/ts/utils/util.py b/ts/utils/util.py index 0a5f914a95..208e3c199d 100644 --- a/ts/utils/util.py +++ b/ts/utils/util.py @@ -136,7 +136,7 @@ def map_class_to_label(probs, mapping=None, lbl_classes=None): return results def get_yaml_config(yaml_file_path): - config = None + config = dict() with open(yaml_file_path, 'r') as file: config = yaml.safe_load(file) return config From fdf8d5174f07704add9f1b98814831698fdeeaa6 Mon Sep 17 00:00:00 2001 From: lxning Date: Mon, 20 Mar 2023 19:27:30 -0700 Subject: [PATCH 28/47] code clean up --- .../org/pytorch/serve/wlm/WorkerLifeCycle.java | 4 +--- ts/model_service_worker.py | 16 ++++++---------- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java index 87d5b1d4f0..e77e99290b 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java @@ -98,9 +98,7 @@ public void startWorker(int port) throws WorkerInitializationException, Interrup ArrayList argl = new ArrayList<>(); if (model.getParallelLevel() > 1) { attachRunner(argl, port); - } - - if (model.getParallelLevel() == 1) { + } else if (model.getParallelLevel() == 1) { argl.add(EnvironmentUtils.getPythonRunTime(model)); } diff --git a/ts/model_service_worker.py b/ts/model_service_worker.py index 03b7ee792c..0afca19f12 100644 --- a/ts/model_service_worker.py +++ b/ts/model_service_worker.py @@ -17,13 +17,14 @@ from ts.protocol.otf_message_handler import create_load_model_response, retrieve_msg MAX_FAILURE_THRESHOLD = 5 -SOCKET_ACCEPT_TIMEOUT = 300.0 +SOCKET_ACCEPT_TIMEOUT = 30.0 DEBUG = False BENCHMARK = os.getenv("TS_BENCHMARK") BENCHMARK = BENCHMARK in ["True", "true", "TRUE"] LOCAL_RANK = int(os.getenv('LOCAL_RANK', 0)) WORLD_SIZE = int(os.getenv('WORLD_SIZE', 0)) WORLD_RANK = int(os.getenv('RANK', 0)) +LOCAL_WORLD_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", 0)) class TorchModelServiceWorker(object): """ @@ -44,8 +45,8 @@ def __init__( if s_name is None: raise ValueError("Wrong arguments passed. No socket name given.") s_name_parts = s_name.rsplit('.', 1) - logging.info(f"part0={s_name_parts[0]}, part1={s_name_parts[1]}, pid={str(os.getpid())}") - s_name_new = s_name_parts[0] + '.' + str(int(s_name_parts[1]) + WORLD_RANK) + logging.info("s_name_part0=%s, s_name_part1=%s, pid=%d", s_name_parts[0], s_name_parts[1], os.getpid()) + s_name_new = s_name_parts[0] + '.' + str(int(s_name_parts[1]) + LOCAL_RANK) self.sock_name, self.port = s_name_new, -1 try: os.remove(s_name_new) @@ -59,7 +60,7 @@ def __init__( self.sock_name = host_addr if host_addr is not None else "127.0.0.1" if port_num is None: raise ValueError("Wrong arguments passed. 
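For reference, the per-rank socket naming added above, together with the port offset in the next hunk, works out as follows (the path and base port here are illustrative):

```python
# Each torchrun-launched worker derives its own socket name / TCP port from LOCAL_RANK.
local_rank = 2
s_name = "/home/model-server/tmp/.ts.sock.9000"
base, suffix = s_name.rsplit(".", 1)
print(f"{base}.{int(suffix) + local_rank}")   # /home/model-server/tmp/.ts.sock.9002
base_port = 9000
print(base_port + local_rank)                 # 9002 when binding over TCP instead
```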
No socket port given.") - self.port = port_num + self.port = port_num + LOCAL_RANK else: raise ValueError("Incomplete data provided") @@ -179,8 +180,7 @@ def run_server(self): else: self.sock.bind((self.sock_name, int(self.port))) - # self.sock.listen(1) - self.sock.listen(128) + self.sock.listen(1) logging.info("[PID]%d", os.getpid()) logging.info("Torch worker started.") @@ -214,8 +214,6 @@ def run_server(self): port = args.port metrics_config = args.metrics_config - print(f"LOCAL_RANK={str(LOCAL_RANK)}, WORLD_SIZE={str(WORLD_SIZE)}, WORLD_RANK={str(WORLD_RANK)}") - if BENCHMARK: import cProfile @@ -226,9 +224,7 @@ def run_server(self): worker = TorchModelServiceWorker( sock_type, socket_name, host, port, metrics_config ) - worker.run_server() - if BENCHMARK: pr.disable() pr.dump_stats("/tmp/tsPythonProfile.prof") From 2b978b622ca057262808423e40a9ee5777f8775d Mon Sep 17 00:00:00 2001 From: lxning Date: Tue, 21 Mar 2023 21:23:39 -0700 Subject: [PATCH 29/47] update config yaml --- examples/Huggingface_Largemodels/model-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/Huggingface_Largemodels/model-config.yaml b/examples/Huggingface_Largemodels/model-config.yaml index 29c1314de0..65c680a3ec 100644 --- a/examples/Huggingface_Largemodels/model-config.yaml +++ b/examples/Huggingface_Largemodels/model-config.yaml @@ -5,7 +5,7 @@ batchSize: 1 maxBatchDelay: 100 responseTimeout: 120 deviceType: cpu # cpu, gpu, neuron -deviceIds: [0,1,2.3] # device index for gpu, neuron +deviceIds: [0,1,2,3] # device index for gpu, neuron parallelLevel: 4 # rpc world size parallelType: pp # pp: pipeline parallel; pptp: tensor+pipeline parallel From ab43251a1e445d327735219da30f5634c9965d74 Mon Sep 17 00:00:00 2001 From: lxning Date: Tue, 21 Mar 2023 22:45:45 -0700 Subject: [PATCH 30/47] update gradle --- frontend/archive/build.gradle | 3 --- frontend/gradle.properties | 3 +-- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/frontend/archive/build.gradle b/frontend/archive/build.gradle index 234238590d..3f301d1b86 100644 --- a/frontend/archive/build.gradle +++ b/frontend/archive/build.gradle @@ -5,9 +5,6 @@ dependencies { api "com.google.code.gson:gson:${gson_version}" implementation "org.yaml:snakeyaml:${snakeyaml_version}" - compileOnly "org.projectlombok:lombok:${lombok_version}" - annotationProcessor "org.projectlombok:lombok:${lombok_version}" - testImplementation "commons-cli:commons-cli:${commons_cli_version}" testImplementation "org.testng:testng:${testng_version}" } diff --git a/frontend/gradle.properties b/frontend/gradle.properties index 42cd7c29ac..91ec64fe6e 100644 --- a/frontend/gradle.properties +++ b/frontend/gradle.properties @@ -11,5 +11,4 @@ torchserve_sdk_version=0.0.4 snakeyaml_version=1.31 grpc_version=1.50.0 protoc_version=3.18.0 -lmax_disruptor_version=3.4.4 -lombok_version=1.18.26 \ No newline at end of file +lmax_disruptor_version=3.4.4 \ No newline at end of file From 60441bcc1136f3262c72a96eab71a46eb0d446f5 Mon Sep 17 00:00:00 2001 From: lxning Date: Wed, 22 Mar 2023 15:39:56 -0700 Subject: [PATCH 31/47] update docs --- docs/configuration.md | 5 ++++ .../java/org/pytorch/serve/ModelServer.java | 3 +-- .../serve/grpcimpl/ManagementImpl.java | 16 ++++++------- .../pytorch/serve/wlm/WorkLoadManager.java | 2 +- model-archiver/README.md | 23 ++++++++++++------- .../test_example_near_real_time_video.py | 1 + .../test_example_scriptable_tokenzier.py | 1 + test/pytest/test_example_torchrec_dlrm.py | 1 + 8 files changed, 33 insertions(+), 19 
deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 8177b88687..f7e8139d47 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -266,6 +266,11 @@ models={\ }\ } ``` +As of version 0.8.0, TorchServe allows for model configuration using a YAML file embedded in the MAR file. This YAML file contains two distinct parts that determine how a model is configured: frontend parameters and backend parameters. (see [details](../model-archiver/README.md)) + +* The frontend parameters are controlled by TorchServe's frontend and specify the parameter name and default values. TorchServe now uses a priority order to determine the final value of a model's parameters in frontend. Specifically, the config.property file has the lowest priority, followed by the model configuration YAML file, and finally, the REST or gRPC model management API has the highest priority. + +* The backend parameters are fully controlled by the user. Users customized handler can access the backend parameters via the `model_yaml_config` property of the context object. ### Other properties diff --git a/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java b/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java index 9a284febc2..627d27217e 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java +++ b/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java @@ -58,8 +58,6 @@ import org.slf4j.LoggerFactory; public class ModelServer { - - public static final int MAX_RCVBUF_SIZE = 4096; private Logger logger = LoggerFactory.getLogger(ModelServer.class); private ServerGroups serverGroups; private Server inferencegRPCServer; @@ -67,6 +65,7 @@ public class ModelServer { private List futures = new ArrayList<>(2); private AtomicBoolean stopped = new AtomicBoolean(false); private ConfigManager configManager; + public static final int MAX_RCVBUF_SIZE = 4096; /** Creates a new {@code ModelServer} instance. 
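The "backend parameters" note in the configuration doc above is easiest to see with a small handler sketch; the class name is illustrative, while the `pippy.pp_group_size` key matches the example config added elsewhere in this series:

```python
from ts.torch_handler.base_handler import BaseHandler


class ConfigAwareHandler(BaseHandler):
    """Illustrative handler showing how a backend reads the model-config YAML."""

    def initialize(self, ctx):
        # Parsed model-config YAML, populated when the MAR manifest points at a config file;
        # treat it as optional so the handler also works without one.
        yaml_config = getattr(ctx, "model_yaml_config", None) or {}
        self.pp_group_size = int(yaml_config.get("pippy", {}).get("pp_group_size", 1))
        super().initialize(ctx)
```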
*/ public ModelServer(ConfigManager configManager) { diff --git a/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/ManagementImpl.java b/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/ManagementImpl.java index 3ad18221dc..f254729b13 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/ManagementImpl.java +++ b/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/ManagementImpl.java @@ -34,14 +34,6 @@ public class ManagementImpl extends ManagementAPIsServiceImplBase { private static final Logger logger = LoggerFactory.getLogger(ManagementImpl.class); - public static void sendErrorResponse( - StreamObserver responseObserver, Status status, Exception e) { - responseObserver.onError( - status.withDescription(e.getMessage()) - .augmentDescription(e.getClass().getCanonicalName()) - .asRuntimeException()); - } - @Override public void describeModel( DescribeModelRequest request, StreamObserver responseObserver) { @@ -239,6 +231,14 @@ private void sendResponse(StreamObserver responseObserver, S responseObserver.onCompleted(); } + public static void sendErrorResponse( + StreamObserver responseObserver, Status status, Exception e) { + responseObserver.onError( + status.withDescription(e.getMessage()) + .augmentDescription(e.getClass().getCanonicalName()) + .asRuntimeException()); + } + private void sendErrorResponse( StreamObserver responseObserver, Status status, diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java index c8915143b7..27ea12f249 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java @@ -145,7 +145,7 @@ public CompletableFuture modelChanged( // Need to check worker process here since thread.shutdown() -> lifecycle.exit() // -> This may nullify process object per destroyForcibly doc. - if ((workerProcess != null) && workerProcess.isAlive()) { + if (workerProcess != null && workerProcess.isAlive()) { boolean workerDestroyed = false; try { String cmd = String.format(OSUtils.getKillCmd(), workerProcess.pid()); diff --git a/model-archiver/README.md b/model-archiver/README.md index 596d97b63e..25f53aa95e 100644 --- a/model-archiver/README.md +++ b/model-archiver/README.md @@ -158,14 +158,21 @@ For more details refer [default handler documentation](../docs/default_handlers. A model config yaml file. For example: ``` -minWorkers: 1 -maxWorkers: 1 -maxBatchDelay: 100 -responseTimeout: 120 -parallelLevel: 4 -parallelMode: "pp" # pp: pipeline parallel; tp: tensor parallel; pptp: pipeline+tensor parallel -deviceType: "gpu" #cpu, gpu, neuron -deviceIds: [0,1,2,3] +# TS Frontend parameters +minWorkers: 1 # default: #CPU or #GPU +maxWorkers: 1 # default: #CPU or #GPU +batchSize: 1 # default: 1 +maxBatchDelay: 100 # default: 100 msec +responseTimeout: 120 # default: 120 sec +deviceType: cpu # cpu, gpu, neuron +deviceIds: [0,1,2,3] # device index for gpu, neuron. Default: all visible devices +parallelLevel: 4 # rpc world size. Default: 1 +parallelType: pp # pp: pipeline parallel; pptp: tensor+pipeline parallel. 
Default: empty + +# TS backend parameters +pippy: + rpc_timeout: 1800 + pp_group_size: 4 # pipeline parallel size, tp_group_size = world size / pp_group_size ``` ## Creating a Model Archive diff --git a/test/pytest/test_example_near_real_time_video.py b/test/pytest/test_example_near_real_time_video.py index 7b48c3147b..093c9f97d7 100644 --- a/test/pytest/test_example_near_real_time_video.py +++ b/test/pytest/test_example_near_real_time_video.py @@ -61,6 +61,7 @@ def create_mar_file(work_dir, session_mocker, model_archiver): ).as_posix(), export_path=work_dir, requirements_file=None, + config_file=None, runtime="python", force=False, archive_format="default", diff --git a/test/pytest/test_example_scriptable_tokenzier.py b/test/pytest/test_example_scriptable_tokenzier.py index c16b4cf364..fcb4dce0cc 100644 --- a/test/pytest/test_example_scriptable_tokenzier.py +++ b/test/pytest/test_example_scriptable_tokenzier.py @@ -157,6 +157,7 @@ def create_mar_file(work_dir, session_mocker, jit_file_path, model_archiver): extra_files=os.path.join(EXAMPLE_ROOT_DIR, "index_to_name.json"), export_path=work_dir, requirements_file=None, + config_file=None, runtime="python", force=False, archive_format="default", diff --git a/test/pytest/test_example_torchrec_dlrm.py b/test/pytest/test_example_torchrec_dlrm.py index c0808fc5d7..505f4ecf9c 100644 --- a/test/pytest/test_example_torchrec_dlrm.py +++ b/test/pytest/test_example_torchrec_dlrm.py @@ -101,6 +101,7 @@ def create_mar_file(work_dir, session_mocker, serialized_file, model_archiver): + EXAMPLE_ROOT_DIR.joinpath("dlrm_model_config.py").as_posix(), export_path=work_dir, requirements_file=None, + config_file=None, runtime="python", force=False, archive_format="default", From 5c956820995d344b7dc13721584c178e5353bc66 Mon Sep 17 00:00:00 2001 From: lxning Date: Wed, 22 Mar 2023 15:53:55 -0700 Subject: [PATCH 32/47] clean up code --- .../Huggingface_Largemodels/model-config.yaml | 21 - .../pippy_pp_handler.py | 309 -------------- .../pippy_pptp_handler.py | 376 ------------------ 3 files changed, 706 deletions(-) delete mode 100644 examples/Huggingface_Largemodels/model-config.yaml delete mode 100644 examples/Huggingface_Largemodels/pippy_pp_handler.py delete mode 100644 examples/Huggingface_Largemodels/pippy_pptp_handler.py diff --git a/examples/Huggingface_Largemodels/model-config.yaml b/examples/Huggingface_Largemodels/model-config.yaml deleted file mode 100644 index 65c680a3ec..0000000000 --- a/examples/Huggingface_Largemodels/model-config.yaml +++ /dev/null @@ -1,21 +0,0 @@ -# TS Frontend parameters -minWorkers: 1 -maxWorkers: 1 -batchSize: 1 -maxBatchDelay: 100 -responseTimeout: 120 -deviceType: cpu # cpu, gpu, neuron -deviceIds: [0,1,2,3] # device index for gpu, neuron -parallelLevel: 4 # rpc world size -parallelType: pp # pp: pipeline parallel; pptp: tensor+pipeline parallel - -torchrun: - max_restarts: 3 - - -# TS backend parameters -pippy: - rpc_timeout: 1800 - pp_group_size: 4 # pipeline parallel size, tp_group_size = world size / pp_group_size - - diff --git a/examples/Huggingface_Largemodels/pippy_pp_handler.py b/examples/Huggingface_Largemodels/pippy_pp_handler.py deleted file mode 100644 index b726e4a8a8..0000000000 --- a/examples/Huggingface_Largemodels/pippy_pp_handler.py +++ /dev/null @@ -1,309 +0,0 @@ -import json -import logging -import os -import zipfile -from abc import ABC - -import argparse -import inspect -import logging -import os -import time - -import pippy.fx -#from pippy import run_pippy -from pippy.IR import 
MultiUseParameterConfig, Pipe -from pippy.PipelineDriver import PipelineDriverFillDrain, PipelineDriver1F1B, PipelineDriverInterleaved1F1B, \ - PipelineDriverBase -from pippy.hf import PiPPyHFTracer -from pippy.microbatch import TensorChunkSpec -from pippy import split_on_size_threshold, split_into_equal_size -from transformers import AutoModelForSeq2SeqLM -from transformers import OPTModel, BloomModel -from PIL import Image -import requests -from transformers import AutoFeatureExtractor, RegNetModel -from transformers import OPTForCausalLM -import torch.distributed.rpc as rpc - -import torch -import transformers -from transformers import BloomForCausalLM, BloomTokenizerFast - -from ts.torch_handler.base_handler import BaseHandler - -logger = logging.getLogger(__name__) -logger.info("Transformers version %s", transformers.__version__) - -PIPPY_VERBOSITY = os.environ.get("PIPPY_VERBOSITY", "DEBUG") -TORCH_DTYPES = { - "float16": torch.float16, - "float32": torch.float32, - "float64": torch.float64, -} - -schedules = { - 'FillDrain': PipelineDriverFillDrain, - '1F1B': PipelineDriver1F1B, - 'Interleaved1F1B': PipelineDriverInterleaved1F1B, -} - -class TransformersSeqClassifierHandler(BaseHandler, ABC): - """ - Transformers handler class for sequence, token classification and question answering. - """ - - def __init__(self): - super(TransformersSeqClassifierHandler, self).__init__() - self.initialized = False - self.local_rank = int(os.environ["LOCAL_RANK"]) - self.world_size = int(os.environ["WORLD_SIZE"]) - - options = rpc.TensorPipeRpcBackendOptions( - num_worker_threads=512, - rpc_timeout=1800 - # transports=None, - ) - - - # if args.cuda: - n_devs = torch.cuda.device_count() - print(f"n_devs={n_devs}") - dev_id = self.local_rank % n_devs - for i in range (self.world_size): - print(f"worker{i}, {dev_id}: {i % n_devs}") - options.set_device_map(f"worker{i}", {dev_id: i % n_devs}) - - self.device = f"cuda:{dev_id}" - print( - f"rank = {self.local_rank} pid/device = " - f"{os.getpid()}/{self.device}" - ) - - rpc.init_rpc(f"worker{self.local_rank}", - rank=self.local_rank, - world_size=self.world_size, - rpc_backend_options=options) - - def initialize(self, ctx): - """In this initialize function, the BERT model is loaded and - the Layer Integrated Gradients Algorithm for Captum Explanations - is initialized here. - Args: - ctx (context): It is a JSON Object containing information - pertaining to the model artefacts parameters. - """ - # parser = argparse.ArgumentParser() - # args = parser.parse_args() - # args.world_size = 4 - # args.gspmd = 1 - #if self.local_rank != 0: - # return - - self.manifest = ctx.manifest - properties = ctx.system_properties - model_dir = properties.get("model_dir") - - # self.device = torch.device( - # "cuda:" + str(properties.get("gpu_id")) - # if torch.cuda.is_available() and properties.get("gpu_id") is not None - # else "cpu" - # ) - # Loading the model and tokenizer from checkpoint and config files based on the user's choice of mode - # further setup config can be added. - with zipfile.ZipFile(model_dir + "/model.zip", "r") as zip_ref: - zip_ref.extractall(model_dir + "/model") - - # read configs for the mode, model_name, etc. 
from setup_config.json - setup_config_path = os.path.join(model_dir, "setup_config.json") - if os.path.isfile(setup_config_path): - with open(setup_config_path) as setup_config_file: - self.setup_config = json.load(setup_config_file) - else: - logger.warning("Missing the setup_config.json file.") - - torch.manual_seed(42) - replicate = 0 - schedule = list(schedules.keys())[0] - MULTI_USE_PARAM_CONFIG = MultiUseParameterConfig.REPLICATE if replicate else MultiUseParameterConfig.TRANSMIT - print(f'REPLICATE config: {replicate} -> {MULTI_USE_PARAM_CONFIG}') - print("Using schedule:", schedule) - - model = BloomModel.from_pretrained( - model_dir + "/model", use_cache=False) - - self.tokenizer = BloomTokenizerFast.from_pretrained( - model_dir + "/model", return_tensors="pt" - ) - - logger.info("********************* model loaded *************************", model_dir) - - # model = BloomModel.from_pretrained("bigscience/bloom-3b", use_cache=False) - - model_config = model.config - - model_config.use_cache = False # don't output `past_key_values` - model.eval() - - - split_policy = split_into_equal_size(self.world_size) - pp_ranks = [0,1,2,3] - all_worker_ranks = list(range(self.world_size)) - chunks = 1 - bs = 1 * chunks - seq_length = 16 - - - input_names = ['input_ids'] - sig = inspect.signature(model.forward) - concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} - - print('Instantiating model Pipeline') - model_init_start = time.time() - pipe_driver, stage_mode = pippy.all_compile( - model, - num_ranks=self.world_size, - num_chunks=chunks, - schedule="FillDrain", - split_policy=split_policy, - tracer=PiPPyHFTracer(), - concrete_args=concrete_args, - ) - # model_pipe = Pipe.from_tracing(self.model, MULTI_USE_PARAM_CONFIG, tracer=PiPPyHFTracer(), concrete_args=concrete_args, - # output_loss_value_spec=None, split_policy=split_policy - # ) - - # model_pipe.defer_stage_init(self.device + self.local_rank) - - # pippy.utils.pp_group_barrier() - - # split_gm_children = list(model_pipe.split_gm.children()) - - # pipe_driver: PipelineDriverBase = schedules["FillDrain"](model_pipe, chunks, - # world_size=self.world_size, - # all_ranks=all_worker_ranks, - # ) - - self.model = pipe_driver - logger.info("Transformer model from path %s loaded successfully", model_dir) - - self.initialized = True - - - def preprocess(self, requests): - """Basic text preprocessing, based on the user's chocie of application mode. - Args: - requests (str): The Input data in the form of text is passed on to the preprocess - function. - Returns: - list : The preprocess function returns a list of Tensor for the size of the word tokens. - """ - input_ids_batch = None - attention_mask_batch = None - for idx, data in enumerate(requests): - input_text = data.get("data") - if input_text is None: - input_text = data.get("body") - if isinstance(input_text, (bytes, bytearray)): - input_text = input_text.decode("utf-8") - - max_length = self.setup_config["max_length"] - logger.info("Received text: '%s'", input_text) - - inputs = self.tokenizer.encode_plus( - input_text, - max_length=int(max_length), - pad_to_max_length=True, - add_special_tokens=True, - return_tensors="pt", - ) - - input_ids = inputs["input_ids"].to(self.device) - attention_mask = inputs["attention_mask"].to(self.device) - # making a batch out of the recieved requests - # attention masks are passed for cases where input tokens are padded. 
- if input_ids.shape is not None: - if input_ids_batch is None: - input_ids_batch = input_ids - attention_mask_batch = attention_mask - else: - input_ids_batch = torch.cat((input_ids_batch, input_ids), 0) - attention_mask_batch = torch.cat( - (attention_mask_batch, attention_mask), 0 - ) - return (input_ids_batch, attention_mask_batch) - - def inference(self, input_batch): - """Predict the class (or classes) of the received text using the - serialized transformers checkpoint. - Args: - input_batch (list): List of Text Tensors from the pre-process function is passed here - Returns: - list : It returns a list of the predicted value for the input text - """ - (input_ids_batch, _) = input_batch - inferences = [] - input_ids_batch = input_ids_batch.to(self.device) - model_input_dict = {} - model_input_dict["input_ids"]=input_ids_batch - # outputs = self.model.generate( - # input_ids_batch, do_sample=True, max_length=50, top_p=0.95, top_k=60 - # ) - # for i, _ in enumerate(outputs): - # inferences.append( - # self.tokenizer.decode(outputs[i], skip_special_tokens=True) - # ) - if self.local_rank==0: - output = self.model(**model_input_dict) - # rpc.shutdown() - print("************** here is the output",type(output)) - logger.info("Generated text: '%s'", inferences) - inferences.append(output) - print("Generated text", inferences) - return inferences - - def postprocess(self, inference_output): - """Post Process Function converts the predicted response into Torchserve readable format. - Args: - inference_output (list): It contains the predicted response of the input text. - Returns: - (list): Returns a list of the Predictions and Explanations. - """ - return inference_output - - def handle(self, data, context): - if self.local_rank != 0: - pass - start_time = time.time() - - self.context = context - metrics = self.context.metrics - - #run_pippy(self.initialize, context) - - is_profiler_enabled = os.environ.get("ENABLE_TORCH_PROFILER", None) - if is_profiler_enabled: - if PROFILER_AVAILABLE: - output, _ = self._infer_with_profiler(data=data) - else: - raise RuntimeError( - "Profiler is enabled but current version of torch does not support." - "Install torch>=1.8.1 to use profiler." 
- ) - else: - if self._is_describe(): - output = [self.describe_handle()] - else: - data_preprocess = self.preprocess(data) - - if not self._is_explain(): - output = self.inference(data_preprocess) - output = self.postprocess(output) - else: - output = self.explain_handle(data_preprocess, data) - - stop_time = time.time() - metrics.add_time( - "HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms" - ) - return output diff --git a/examples/Huggingface_Largemodels/pippy_pptp_handler.py b/examples/Huggingface_Largemodels/pippy_pptp_handler.py deleted file mode 100644 index 7776aee429..0000000000 --- a/examples/Huggingface_Largemodels/pippy_pptp_handler.py +++ /dev/null @@ -1,376 +0,0 @@ -import json -import zipfile -from abc import ABC - -import inspect -import logging -import os -import time - -import torch -import pippy -import pippy.fx -from torch.distributed._tensor import ( - DeviceMesh, -) -from torch.distributed.tensor.parallel import ( - PairwiseParallel, - parallelize_module, -) - -from pippy.IR import MultiUseParameterConfig, Pipe -from pippy.PipelineDriver import PipelineDriverFillDrain, PipelineDriver1F1B, PipelineDriverInterleaved1F1B, \ - PipelineDriverBase -from pippy import split_on_size_threshold, split_into_equal_size -from transformers import OPTModel, BloomModel -import torch.distributed.rpc as rpc - -import transformers -from transformers import BloomForCausalLM, BloomTokenizerFast -from ts.torch_handler.base_handler import BaseHandler - -logger = logging.getLogger(__name__) -logger.info("Transformers version %s", transformers.__version__) - -PIPPY_VERBOSITY = os.environ.get("PIPPY_VERBOSITY", "DEBUG") -TORCH_DTYPES = { - "float16": torch.float16, - "float32": torch.float32, - "float64": torch.float64, -} - -schedules = { - 'FillDrain': PipelineDriverFillDrain, - '1F1B': PipelineDriver1F1B, - 'Interleaved1F1B': PipelineDriverInterleaved1F1B, -} - -class TransformersSeqClassifierHandler(BaseHandler, ABC): - """ - Transformers handler class for sequence, token classification and question answering. - """ - - def __init__(self): - super(TransformersSeqClassifierHandler, self).__init__() - self.initialized = False - self.local_rank = int(os.environ["LOCAL_RANK"]) - self.rank = int(os.environ["RANK"]) - self.world_size = int(os.environ["WORLD_SIZE"]) - self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) - self.pp_rank = 0 - self.pp_ranks = None - - def initialize(self, ctx): - """In this initialize function, the BERT model is loaded and - the Layer Integrated Gradients Algorithm for Captum Explanations - is initialized here. - Args: - ctx (context): It is a JSON Object containing information - pertaining to the model artefacts parameters. 
- """ - pp_group_size = self.world_size - num_worker_threads = 512 - rpc_timeout = 1800 - if ctx.model_yaml_config is not None: - if ctx.system_properties.get("gpu_id") != -1 \ - and "deviceIds" in ctx.model_yaml_config: - device_ids = ','.join(str(e) for e in ctx.model_yaml_config["deviceIds"][int(ctx.system_properties.get("gpu_id")):int(ctx.system_properties.get("gpu_id"))+self.loca_world_size+1]) - os.environ["CUDA_VISIBLE_DEVICE"] = device_ids - - if "pippy" in ctx.model_yaml_config: - if "pp_group_size" in ctx.model_yaml_config["pippy"] \ - and self.world_size % int(ctx.model_yaml_config["pippy"]["pp_group_size"]) == 0: - pp_group_size = int(ctx.model_yaml_config["pippy"]["pp_group_size"]) - - if "num_worker_threads" in ctx.model_yaml_config["pippy"]: - num_worker_threads = int(ctx.model_yaml_config["pippy"]["num_worker_threads"]) - - if "rpc_timeout" in ctx.model_yaml_config["pippy"]: - rpc_timeout = int(ctx.model_yaml_config["pippy"]["rpc_timeout"]) - - if ctx.system_properties.get("gpu_id") != -1 and os.environ["CUDA_VISIBLE_DEVICE"] is None: - os.environ["CUDA_VISIBLE_DEVICE"] = ','.join(str(e) for e in range(int(ctx.system_properties.get("gpu_id")), int(ctx.system_properties.get("gpu_id")) + self.loca_world_size+1)) - - options = rpc.TensorPipeRpcBackendOptions( - num_worker_threads, - rpc_timeout - ) - device_type = "cpu" - if int(ctx.system_properties.get("gpu_id")) != -1: - device_type = "cuda" - n_devs = torch.cuda.device_count() - dev_id = self.local_rank % n_devs - for i in range(self.world_size): - logging.info(f"worker{i}, {dev_id}: {i % n_devs}") - options.set_device_map(f"worker{i}", {dev_id: i % n_devs}) - - self.device = f"cuda:{dev_id}" - logging.info( - f"rank = {self.local_rank} pid/device = " - f"{os.getpid()}/{self.device}" - ) - else: - self.device = "cpu" - rpc.init_rpc(f"worker{self.rank}", - rank=self.rank, - world_size=self.world_size, - rpc_backend_options=options) - - tp_group_size = self.world_size // pp_group_size - dp_group_size = self.world_size // pp_group_size - - logging.info( - f"[PiPPy] World size: {self.world_size}, " - f"DP group size: {dp_group_size}, " - f"PP group size: {pp_group_size}" - ) - pp_ranks_per_dp_group = [ - [i + rank for i in range(pp_group_size)] - for rank in range(dp_group_size) - ] - self.pp_ranks = pp_ranks_per_dp_group[self.rank % dp_group_size] - self.pp_rank = self.rank // tp_group_size - logging.info(f"Global rank {self.rank}, pipeline: {self.pp_ranks}, my rank in pipe: {self.pp_rank}") - - d_hid = 256 - batch_size_per_chunk = 8 - chunks = pp_group_size - #inp_size = [chunks * batch_size_per_chunk, d_hid] - # Ensure all tp ranks have same input. - #torch.manual_seed(0) - #inp = torch.rand(*inp_size, device=device_type) - - self.manifest = ctx.manifest - properties = ctx.system_properties - model_dir = properties.get("model_dir") - - # self.device = torch.device( - # "cuda:" + str(properties.get("gpu_id")) - # if torch.cuda.is_available() and properties.get("gpu_id") is not None - # else "cpu" - # ) - # Loading the model and tokenizer from checkpoint and config files based on the user's choice of mode - # further setup config can be added. - with zipfile.ZipFile(model_dir + "/model.zip", "r") as zip_ref: - zip_ref.extractall(model_dir + "/model") - - # read configs for the mode, model_name, etc. 
from setup_config.json - setup_config_path = os.path.join(model_dir, "setup_config.json") - if os.path.isfile(setup_config_path): - with open(setup_config_path) as setup_config_file: - self.setup_config = json.load(setup_config_file) - else: - logger.warning("Missing the setup_config.json file.") - - - replicate = 0 - schedule = list(schedules.keys())[0] - MULTI_USE_PARAM_CONFIG = MultiUseParameterConfig.REPLICATE if replicate else MultiUseParameterConfig.TRANSMIT - print(f'REPLICATE config: {replicate} -> {MULTI_USE_PARAM_CONFIG}') - print("Using schedule:", schedule) - - model = BloomModel.from_pretrained( - model_dir + "/model", use_cache=False) - - self.tokenizer = BloomTokenizerFast.from_pretrained( - model_dir + "/model", return_tensors="pt" - ) - - logger.info("********************* model loaded *************************", model_dir) - - # model = BloomModel.from_pretrained("bigscience/bloom-3b", use_cache=False) - - model_config = model.config - - model_config.use_cache = False # don't output `past_key_values` - model.eval() - - - split_policy = split_into_equal_size(self.world_size) - - - input_names = ['input_ids'] - sig = inspect.signature(model.forward) - concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} - - torch.manual_seed(0) - ec_tp = model(d_hid) - ec_tp.to(self.device) - start_idx = 0 - device_mesh = DeviceMesh( - device_type, - list(range(start_idx, start_idx + tp_group_size)), - ) - logging.info(f"Rank {self.rank} calling parallelize_module with {device_mesh}") - parallelize_module(ec_tp, device_mesh, PairwiseParallel()) - logging.info(f"Rank {self.rank} sharding complete") - - print('Instantiating model Pipeline') - model_init_start = time.time() - # Get: - # - pipeline driver (for pipeline head rank) - # - stage submodule (for all ranks) - pipe_driver, submod = pippy.all_compile( - model, - pp_group_size, - chunks, - ranks=self.pp_ranks, - ) - - # Create TP device mesh - my_device_mesh = None - for stage in range(pp_group_size): - start_rank = stage * tp_group_size - tp_ranks = list(range(start_rank, start_rank + tp_group_size)) - tp_device_mesh = DeviceMesh( - device_type, - tp_ranks, - ) - if stage == self.pp_rank: - my_device_mesh = tp_device_mesh - - # Tensor parallelize submodules - print(f"Rank {self.rank} calling parallelize_module with {my_device_mesh}") - parallelize_module(submod, my_device_mesh, PairwiseParallel()) - - - # model_pipe = Pipe.from_tracing(self.model, MULTI_USE_PARAM_CONFIG, tracer=PiPPyHFTracer(), concrete_args=concrete_args, - # output_loss_value_spec=None, split_policy=split_policy - # ) - - # model_pipe.defer_stage_init(self.device + self.local_rank) - - # pippy.utils.pp_group_barrier() - - # split_gm_children = list(model_pipe.split_gm.children()) - - # pipe_driver: PipelineDriverBase = schedules["FillDrain"](model_pipe, chunks, - # world_size=self.world_size, - # all_ranks=all_worker_ranks, - # ) - - self.model = pipe_driver - logger.info("Transformer model from path %s loaded successfully", model_dir) - - self.initialized = True - - - def preprocess(self, requests): - """Basic text preprocessing, based on the user's chocie of application mode. - Args: - requests (str): The Input data in the form of text is passed on to the preprocess - function. - Returns: - list : The preprocess function returns a list of Tensor for the size of the word tokens. 
- """ - input_ids_batch = None - attention_mask_batch = None - for idx, data in enumerate(requests): - input_text = data.get("data") - if input_text is None: - input_text = data.get("body") - if isinstance(input_text, (bytes, bytearray)): - input_text = input_text.decode("utf-8") - - max_length = self.setup_config["max_length"] - logger.info("Received text: '%s'", input_text) - - inputs = self.tokenizer.encode_plus( - input_text, - max_length=int(max_length), - pad_to_max_length=True, - add_special_tokens=True, - return_tensors="pt", - ) - - input_ids = inputs["input_ids"].to(self.device) - attention_mask = inputs["attention_mask"].to(self.device) - # making a batch out of the recieved requests - # attention masks are passed for cases where input tokens are padded. - if input_ids.shape is not None: - if input_ids_batch is None: - input_ids_batch = input_ids - attention_mask_batch = attention_mask - else: - input_ids_batch = torch.cat((input_ids_batch, input_ids), 0) - attention_mask_batch = torch.cat( - (attention_mask_batch, attention_mask), 0 - ) - return (input_ids_batch, attention_mask_batch) - - def inference(self, input_batch): - """Predict the class (or classes) of the received text using the - serialized transformers checkpoint. - Args: - input_batch (list): List of Text Tensors from the pre-process function is passed here - Returns: - list : It returns a list of the predicted value for the input text - """ - (input_ids_batch, _) = input_batch - inferences = [] - input_ids_batch = input_ids_batch.to(self.device) - model_input_dict = {} - model_input_dict["input_ids"]=input_ids_batch - # outputs = self.model.generate( - # input_ids_batch, do_sample=True, max_length=50, top_p=0.95, top_k=60 - # ) - # for i, _ in enumerate(outputs): - # inferences.append( - # self.tokenizer.decode(outputs[i], skip_special_tokens=True) - # ) - #if self.pp_rank == 0: - # print(f"Rank {self.rank} Instantiated pipeline with ranks {self.pp_ranks}") - output = self.model(**model_input_dict) - - - print("************** here is the output",type(output)) - logger.info("Generated text: '%s'", inferences) - inferences.append(output) - print("Generated text", inferences) - return inferences - - def postprocess(self, inference_output): - """Post Process Function converts the predicted response into Torchserve readable format. - Args: - inference_output (list): It contains the predicted response of the input text. - Returns: - (list): Returns a list of the Predictions and Explanations. - """ - return inference_output - - def handle(self, data, context): - if self.local_rank != 0: - pass - start_time = time.time() - - self.context = context - metrics = self.context.metrics - - #run_pippy(self.initialize, context) - - is_profiler_enabled = os.environ.get("ENABLE_TORCH_PROFILER", None) - if is_profiler_enabled: - if PROFILER_AVAILABLE: - output, _ = self._infer_with_profiler(data=data) - else: - raise RuntimeError( - "Profiler is enabled but current version of torch does not support." - "Install torch>=1.8.1 to use profiler." 
- ) - else: - if self._is_describe(): - output = [self.describe_handle()] - else: - data_preprocess = self.preprocess(data) - - if not self._is_explain(): - output = self.inference(data_preprocess) - output = self.postprocess(output) - else: - output = self.explain_handle(data_preprocess, data) - - stop_time = time.time() - metrics.add_time( - "HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms" - ) - return output From 298869056151cd4b74fa43f8e5672edc28bd84eb Mon Sep 17 00:00:00 2001 From: lxning Date: Wed, 22 Mar 2023 16:04:41 -0700 Subject: [PATCH 33/47] revert to original doc for the example --- examples/Huggingface_Largemodels/Readme.md | 16 +++++----------- .../Huggingface_Largemodels/config.properties | 16 ++++++++-------- 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/examples/Huggingface_Largemodels/Readme.md b/examples/Huggingface_Largemodels/Readme.md index 636e3b0885..273c337a33 100644 --- a/examples/Huggingface_Largemodels/Readme.md +++ b/examples/Huggingface_Largemodels/Readme.md @@ -1,13 +1,7 @@ # Loading large Huggingface models with constrained resources using accelerate -This document briefs on serving large HF model with PiPPy. +This document briefs on serving large HG models with limited resource using accelerate. This option can be activated with `low_cpu_mem_usage=True`. The model is first created on the Meta device (with empty weights) and the state dict is then loaded inside it (shard by shard in the case of a sharded checkpoint). - -### Step 0: Install torchserve from src -```bash -python ts_scripts/install_from_src.py - -``` ### Step 1: Download model Login into huggingface hub with token by running the below command @@ -18,7 +12,7 @@ huggingface-cli login paste the token generated from huggingface hub. ```bash -python Download_model.py --model_name bigscience/bloom-1b1 +python Download_model.py --model_name bigscience/bloom-7b1 ``` The script prints the path where the model is downloaded as below. @@ -34,7 +28,7 @@ Navigate to the path got from the above script. In this example it is ```bash cd model/models--bigscience-bloom-7b1/snapshots/5546055f03398095e385d7dc625e636cc8910bf2/ -zip -r /home/ubuntu/serve/examples/Huggingface_Largemodels/model.zip * +zip -r /home/ubuntu/serve/examples/Huggingface_Largemodels//model.zip * cd - ``` @@ -44,7 +38,7 @@ cd - Navigate up to `Huggingface_Largemodels` directory. 
```bash -torch-model-archiver --model-name bloom --version 1.0 --handler pippy_handler.py --extra-files model.zip,setup_config.json -r requirements.txt +torch-model-archiver --model-name bloom --version 1.0 --handler custom_handler.py --extra-files model.zip,setup_config.json -r requirements.txt ``` **__Note__**: Modifying setup_config.json @@ -64,7 +58,7 @@ mv bloom.mar model_store Update config.properties and start torchserve ```bash -torchserve --ncs --start --model-store model_store --models bloom.mar --ts-config config.properties +torchserve --start --ncs --ts-config config.properties ``` ### Step 5: Run inference diff --git a/examples/Huggingface_Largemodels/config.properties b/examples/Huggingface_Largemodels/config.properties index a01bfe2c28..f02628fde6 100644 --- a/examples/Huggingface_Largemodels/config.properties +++ b/examples/Huggingface_Largemodels/config.properties @@ -1,10 +1,10 @@ -nference_address=http://0.0.0.0:8080 +inference_address=http://0.0.0.0:8080 management_address=http://0.0.0.0:8081 -number_of_netty_threads=32 -job_queue_size=1000 -vmargs=-Xmx4g -XX:+ExitOnOutOfMemoryError -XX:+HeapDumpOnOutOfMemoryError -prefer_direct_buffer=True -default_response_timeout=300 -unregister_model_timeout=300 +metrics_address=http://0.0.0.0:8082 +enable_envvars_config=true install_py_dep_per_model=true -default_workers_per_model=1 +number_of_gpu=1 +load_models=all +max_response_size=655350000 +default_response_timeout=6000 +model_store=/home/ubuntu/serve/examples/Huggingface_Largemodels/model_store From dacc252a7fed73e2f32086fad721e1b112fa2476 Mon Sep 17 00:00:00 2001 From: lxning Date: Wed, 22 Mar 2023 16:06:48 -0700 Subject: [PATCH 34/47] update doc link --- docs/configuration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/configuration.md b/docs/configuration.md index f7e8139d47..cc08105104 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -266,7 +266,7 @@ models={\ }\ } ``` -As of version 0.8.0, TorchServe allows for model configuration using a YAML file embedded in the MAR file. This YAML file contains two distinct parts that determine how a model is configured: frontend parameters and backend parameters. (see [details](../model-archiver/README.md)) +As of version 0.8.0, TorchServe allows for model configuration using a YAML file embedded in the MAR file. This YAML file contains two distinct parts that determine how a model is configured: frontend parameters and backend parameters. (see [details](https://github.com/pytorch/serve/tree/master/model-archiver)) * The frontend parameters are controlled by TorchServe's frontend and specify the parameter name and default values. TorchServe now uses a priority order to determine the final value of a model's parameters in frontend. Specifically, the config.property file has the lowest priority, followed by the model configuration YAML file, and finally, the REST or gRPC model management API has the highest priority. 
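As a minimal sketch of the handler-side access described above (the `YamlAwareHandler` name and the `pippy` section are illustrative assumptions, not taken from these patches), a custom handler might read the backend half of the embedded YAML like this:

```python
# Illustrative sketch only; assumes the MAR was packaged with the model config
# YAML option introduced above, and that the YAML looked roughly like:
#
#   minWorkers: 1        # frontend parameter, consumed by TorchServe itself
#   pippy:               # backend parameters, opaque to the frontend
#     rpc_timeout: 1800
#     pp_group_size: 4
#
import logging

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class YamlAwareHandler(BaseHandler):
    """Toy handler that reads its backend settings from the embedded YAML."""

    def initialize(self, ctx):
        # A real handler would also load the model here; this sketch only
        # shows how the embedded configuration can be read.
        yaml_cfg = getattr(ctx, "model_yaml_config", None) or {}
        backend_cfg = yaml_cfg.get("pippy", {})  # "pippy" is a hypothetical section
        self.rpc_timeout = int(backend_cfg.get("rpc_timeout", 1800))
        self.pp_group_size = int(backend_cfg.get("pp_group_size", 1))
        logger.info(
            "rpc_timeout=%d, pp_group_size=%d", self.rpc_timeout, self.pp_group_size
        )
        self.initialized = True
```

TorchServe's frontend consumes keys such as `minWorkers` or `batchSize` itself, while sections it does not recognize, like the `pippy` block here, are simply carried along in `model_yaml_config` for the handler to interpret.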
From 5485d06888a65d070dc61a6606ace948b6f15a5f Mon Sep 17 00:00:00 2001 From: lxning Date: Wed, 22 Mar 2023 17:27:11 -0700 Subject: [PATCH 35/47] fmt --- .pre-commit-config.yaml | 10 ++--- model-archiver/model_archiver/arg_parser.py | 1 - .../manifest_components/manifest.py | 3 +- .../manifest_components/model.py | 38 +++++++++++-------- .../model_archiver/model_packaging.py | 2 +- .../model_archiver/model_packaging_utils.py | 3 +- .../test_example_scriptable_tokenzier.py | 3 -- ts/context.py | 2 +- ts/model_service_worker.py | 30 +++++++++------ ts/service.py | 11 +++--- ts/utils/util.py | 8 +++- 11 files changed, 63 insertions(+), 48 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a7ee04a103..ec9f575678 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ # See https://pre-commit.com/hooks.html for more hooks repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 + rev: v4.4.0 hooks: - id: check-ast - id: check-builtin-literals @@ -18,23 +18,23 @@ repos: - id: check-vcs-permalinks - id: check-shebang-scripts-are-executable - repo: https://github.com/pre-commit/pygrep-hooks - rev: v1.9.0 + rev: v1.10.0 hooks: - id: python-check-mock-methods - id: python-no-log-warn - id: python-use-type-annotations - repo: https://github.com/hadialqattan/pycln - rev: v1.2.5 + rev: v2.1.3 hooks: - id: pycln args: [--all] - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 23.1.0 hooks: - id: black additional_dependencies: ['click==8.0.4'] - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort args: ["--profile", "black"] diff --git a/model-archiver/model_archiver/arg_parser.py b/model-archiver/model_archiver/arg_parser.py index 7d92964d94..a925d22da4 100644 --- a/model-archiver/model_archiver/arg_parser.py +++ b/model-archiver/model_archiver/arg_parser.py @@ -17,7 +17,6 @@ class ArgParser(object): @staticmethod def export_model_args_parser(): - """Argument parser for torch-model-export""" parser_export = argparse.ArgumentParser( diff --git a/model-archiver/model_archiver/manifest_components/manifest.py b/model-archiver/model_archiver/manifest_components/manifest.py index 8ca037aa5e..f42c804616 100644 --- a/model-archiver/model_archiver/manifest_components/manifest.py +++ b/model-archiver/model_archiver/manifest_components/manifest.py @@ -8,17 +8,16 @@ class RuntimeType(Enum): - PYTHON = "python" PYTHON3 = "python3" + class Manifest(object): """ The main manifest object which gets written into the model archive as MANIFEST.json """ def __init__(self, runtime, model): - self.creation_time = datetime.now().strftime("%d/%m/%Y %H:%M:%S") self.runtime = RuntimeType(runtime) self.model = model diff --git a/model-archiver/model_archiver/manifest_components/model.py b/model-archiver/model_archiver/manifest_components/model.py index d496bc3fca..e6f72a2566 100644 --- a/model-archiver/model_archiver/manifest_components/model.py +++ b/model-archiver/model_archiver/manifest_components/model.py @@ -9,27 +9,35 @@ class Model(object): as the entry point into the service code through the handler property """ - def __init__(self, model_name, serialized_file, handler, model_file=None, model_version=None, - extensions=None, requirements_file=None, config_file=None): - + def __init__( + self, + model_name, + serialized_file, + handler, + model_file=None, + model_version=None, + extensions=None, + requirements_file=None, + config_file=None, + ): self.model_name = model_name self.serialized_file = None if 
serialized_file: - if sys.platform.startswith('win32') and serialized_file.find("\\") != -1: + if sys.platform.startswith("win32") and serialized_file.find("\\") != -1: self.serialized_file = serialized_file.split("\\")[-1] else: self.serialized_file = serialized_file.split("/")[-1] self.model_file = model_file self.model_version = model_version self.extensions = extensions - if sys.platform.startswith('win32') and handler.find("\\") != -1: + if sys.platform.startswith("win32") and handler.find("\\") != -1: self.handler = handler.split("\\")[-1] else: self.handler = handler.split("/")[-1] self.requirements_file = requirements_file self.config_file = None if config_file: - if sys.platform.startswith('win32') and config_file.find("\\") != -1: + if sys.platform.startswith("win32") and config_file.find("\\") != -1: self.config_file = config_file.split("\\")[-1] else: self.config_file = config_file.split("/")[-1] @@ -37,29 +45,29 @@ def __init__(self, model_name, serialized_file, handler, model_file=None, model_ self.model_dict = self.__to_dict__() def __to_dict__(self): - model_dict = dict() + model_dict = {} - model_dict['modelName'] = self.model_name + model_dict["modelName"] = self.model_name if self.serialized_file: - model_dict['serializedFile'] = self.serialized_file + model_dict["serializedFile"] = self.serialized_file - model_dict['handler'] = self.handler + model_dict["handler"] = self.handler if self.model_file: - model_dict['modelFile'] = self.model_file.split("/")[-1] + model_dict["modelFile"] = self.model_file.split("/")[-1] if self.model_version: - model_dict['modelVersion'] = self.model_version + model_dict["modelVersion"] = self.model_version if self.extensions: - model_dict['extensions'] = self.extensions + model_dict["extensions"] = self.extensions if self.requirements_file: - model_dict['requirementsFile'] = self.requirements_file.split("/")[-1] + model_dict["requirementsFile"] = self.requirements_file.split("/")[-1] if self.config_file: - model_dict['configFile'] = self.config_file + model_dict["configFile"] = self.config_file return model_dict diff --git a/model-archiver/model_archiver/model_packaging.py b/model-archiver/model_archiver/model_packaging.py index 023528fbd1..3304f6a4f1 100644 --- a/model-archiver/model_archiver/model_packaging.py +++ b/model-archiver/model_archiver/model_packaging.py @@ -38,7 +38,7 @@ def package_model(args, manifest): "handler": handler, "extra_files": extra_files, "requirements-file": requirements_file, - "config_file": config_file + "config_file": config_file, } model_path = ModelExportUtils.copy_artifacts(model_name, **artifact_files) diff --git a/model-archiver/model_archiver/model_packaging_utils.py b/model-archiver/model_archiver/model_packaging_utils.py index 57c9986858..cbda5e2ef2 100644 --- a/model-archiver/model_archiver/model_packaging_utils.py +++ b/model-archiver/model_archiver/model_packaging_utils.py @@ -107,7 +107,7 @@ def generate_model(modelargs): handler=modelargs.handler, model_version=modelargs.version, requirements_file=modelargs.requirements_file, - config_file=modelargs.config_file + config_file=modelargs.config_file, ) return model @@ -236,7 +236,6 @@ def archive( @staticmethod def archive_dir(path, dst, archive_format, model_name): - """ This method zips the dir and filters out some files based on a expression :param archive_format: diff --git a/test/pytest/test_example_scriptable_tokenzier.py b/test/pytest/test_example_scriptable_tokenzier.py index fcb4dce0cc..781885c462 100644 --- 
a/test/pytest/test_example_scriptable_tokenzier.py +++ b/test/pytest/test_example_scriptable_tokenzier.py @@ -216,7 +216,6 @@ def test_handler(monkeypatch, mocker, jit_file_path, test_file): # We need to recreate the handler to avoid running into https://github.com/pytorch/text/issues/1849 def create_and_call_handler(input_text): - from handler import CustomTextClassifier handler = CustomTextClassifier() @@ -251,7 +250,6 @@ def create_and_call_handler(input_text): def test_inference_with_untrained_model_and_sample_text(model_name, test_file): - with open(test_file, "rb") as f: response = requests.post( url=f"http://localhost:8080/predictions/{model_name}", data=f @@ -270,7 +268,6 @@ def test_inference_with_untrained_model_and_sample_text(model_name, test_file): def test_inference_with_untrained_model_and_empty_string(model_name): - data = "".encode("utf8") response = requests.post( diff --git a/ts/context.py b/ts/context.py index 16b100f6f5..9cbbe25b80 100644 --- a/ts/context.py +++ b/ts/context.py @@ -21,7 +21,7 @@ def __init__( mms_version, limit_max_image_pixels=True, metrics=None, - model_yaml_config=None + model_yaml_config=None, ): self.model_name = model_name self.manifest = manifest diff --git a/ts/model_service_worker.py b/ts/model_service_worker.py index 0afca19f12..d56ad377dc 100644 --- a/ts/model_service_worker.py +++ b/ts/model_service_worker.py @@ -21,11 +21,12 @@ DEBUG = False BENCHMARK = os.getenv("TS_BENCHMARK") BENCHMARK = BENCHMARK in ["True", "true", "TRUE"] -LOCAL_RANK = int(os.getenv('LOCAL_RANK', 0)) -WORLD_SIZE = int(os.getenv('WORLD_SIZE', 0)) -WORLD_RANK = int(os.getenv('RANK', 0)) +LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0)) +WORLD_SIZE = int(os.getenv("WORLD_SIZE", 0)) +WORLD_RANK = int(os.getenv("RANK", 0)) LOCAL_WORLD_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", 0)) + class TorchModelServiceWorker(object): """ Backend worker to handle Model Server's python service code @@ -44,9 +45,14 @@ def __init__( if s_type == "unix": if s_name is None: raise ValueError("Wrong arguments passed. No socket name given.") - s_name_parts = s_name.rsplit('.', 1) - logging.info("s_name_part0=%s, s_name_part1=%s, pid=%d", s_name_parts[0], s_name_parts[1], os.getpid()) - s_name_new = s_name_parts[0] + '.' + str(int(s_name_parts[1]) + LOCAL_RANK) + s_name_parts = s_name.rsplit(".", 1) + logging.info( + "s_name_part0=%s, s_name_part1=%s, pid=%d", + s_name_parts[0], + s_name_parts[1], + os.getpid(), + ) + s_name_new = s_name_parts[0] + "." 
+ str(int(s_name_parts[1]) + LOCAL_RANK) self.sock_name, self.port = s_name_new, -1 try: os.remove(s_name_new) @@ -55,7 +61,7 @@ def __init__( raise RuntimeError( "socket already in use: {}.".format(s_name_new) ) from e - + elif s_type == "tcp": self.sock_name = host_addr if host_addr is not None else "127.0.0.1" if port_num is None: @@ -63,7 +69,7 @@ def __init__( self.port = port_num + LOCAL_RANK else: raise ValueError("Incomplete data provided") - + logging.info("Listening on port: %s", s_name) socket_family = socket.AF_INET if s_type == "tcp" else socket.AF_UNIX self.sock = socket.socket(socket_family, socket.SOCK_STREAM) @@ -71,7 +77,9 @@ def __init__( if self.metrics_cache: self.metrics_cache.initialize_cache() else: - raise RuntimeError(f"Failed to initialize metrics from file {metrics_config}") + raise RuntimeError( + f"Failed to initialize metrics from file {metrics_config}" + ) def load_model(self, load_model_request): """ @@ -128,7 +136,7 @@ def load_model(self, load_model_request): batch_size, envelope, limit_max_image_pixels, - self.metrics_cache + self.metrics_cache, ) logging.debug("Model %s loaded.", model_name) @@ -211,7 +219,7 @@ def run_server(self): socket_name = args.sock_name sock_type = args.sock_type host = args.host - port = args.port + port = args.port metrics_config = args.metrics_config if BENCHMARK: diff --git a/ts/service.py b/ts/service.py index 25bbd39b8f..e95785ed7a 100644 --- a/ts/service.py +++ b/ts/service.py @@ -9,8 +9,7 @@ import ts from ts.context import Context, RequestProcessor from ts.protocol.otf_message_handler import create_predict_response -from ts.utils.util import PredictionException -from ts.utils.util import get_yaml_config +from ts.utils.util import PredictionException, get_yaml_config PREDICTION_METRIC = "PredictionTime" logger = logging.getLogger(__name__) @@ -32,12 +31,14 @@ def __init__( limit_max_image_pixels=True, metrics_cache=None, ): - model_yaml_config = dict() + model_yaml_config = {} if manifest is not None and "model" in manifest: model = manifest["model"] if "configFile" in model: model_yaml_config_file = model["configFile"] - model_yaml_config = get_yaml_config(os.path.join(model_dir, model_yaml_config_file)) + model_yaml_config = get_yaml_config( + os.path.join(model_dir, model_yaml_config_file) + ) if "deviceIds" in model_yaml_config and "parallelLevel" in model_yaml_config: if int(model_yaml_config["parallelLevel"]) == 1: @@ -52,7 +53,7 @@ def __init__( ts.__version__, limit_max_image_pixels, metrics_cache, - model_yaml_config + model_yaml_config, ) self._entry_point = entry_point diff --git a/ts/utils/util.py b/ts/utils/util.py index 208e3c199d..5f0b1a8e18 100644 --- a/ts/utils/util.py +++ b/ts/utils/util.py @@ -8,8 +8,10 @@ import logging import os import re + import yaml + class PT2Backend(str, enum.Enum): EAGER = "eager" AOT_EAGER = "aot_eager" @@ -135,12 +137,14 @@ def map_class_to_label(probs, mapping=None, lbl_classes=None): return results + def get_yaml_config(yaml_file_path): - config = dict() - with open(yaml_file_path, 'r') as file: + config = {} + with open(yaml_file_path, "r") as file: config = yaml.safe_load(file) return config + class PredictionException(Exception): def __init__(self, message, error_code=500): self.message = message From 4877c56267494ca20aa6ae37e4ba39c8f1d6f856 Mon Sep 17 00:00:00 2001 From: lxning Date: Thu, 23 Mar 2023 17:19:58 -0700 Subject: [PATCH 36/47] fmt --- .../manifest_components/model.py | 10 +- .../tests/unit_tests/test_model_packaging.py | 73 +++++--- 
.../unit_tests/test_model_packaging_utils.py | 175 +++++++++++------- .../test_example_near_real_time_video.py | 2 +- .../test_example_scriptable_tokenzier.py | 2 +- test/pytest/test_example_torchrec_dlrm.py | 2 +- .../unit_tests/test_model_service_worker.py | 6 +- 7 files changed, 171 insertions(+), 99 deletions(-) diff --git a/model-archiver/model_archiver/manifest_components/model.py b/model-archiver/model_archiver/manifest_components/model.py index e6f72a2566..1ff2517ba0 100644 --- a/model-archiver/model_archiver/manifest_components/model.py +++ b/model-archiver/model_archiver/manifest_components/model.py @@ -1,5 +1,6 @@ # pylint: disable=missing-docstring import json +import os import sys @@ -30,10 +31,11 @@ def __init__( self.model_file = model_file self.model_version = model_version self.extensions = extensions - if sys.platform.startswith("win32") and handler.find("\\") != -1: - self.handler = handler.split("\\")[-1] - else: - self.handler = handler.split("/")[-1] + # if sys.platform.startswith("win32") and handler.find("\\") != -1: + # self.handler = handler.split("\\")[-1] + # else: + # self.handler = handler.split("/")[-1] + self.handler = os.path.basename(handler) self.requirements_file = requirements_file self.config_file = None if config_file: diff --git a/model-archiver/model_archiver/tests/unit_tests/test_model_packaging.py b/model-archiver/model_archiver/tests/unit_tests/test_model_packaging.py index 6297318704..eef8dd0051 100644 --- a/model-archiver/model_archiver/tests/unit_tests/test_model_packaging.py +++ b/model-archiver/model_archiver/tests/unit_tests/test_model_packaging.py @@ -1,13 +1,11 @@ - - +import sys from collections import namedtuple import pytest -import sys from mock import MagicMock -sys.modules['shutil'] = MagicMock() -sys.modules['shutil.rmtree'] = MagicMock() +sys.modules["shutil"] = MagicMock() +sys.modules["shutil.rmtree"] = MagicMock() from model_archiver.manifest_components.manifest import RuntimeType from model_archiver.model_packaging import generate_model_archive, package_model @@ -16,7 +14,6 @@ # noinspection PyClassHasNoInit class TestModelPackaging: - class Namespace: def __init__(self, **kwargs): self.__dict__.update(kwargs) @@ -24,26 +21,41 @@ def __init__(self, **kwargs): def update(self, **kwargs): self.__dict__.update(kwargs) - model_name = 'my-model' - model_file = 'my-model/' - serialized_file = 'my-model/' - handler = 'a.py::my-awesome-func' - export_path = '/Users/dummyUser/' - version = '1.0' + model_name = "my-model" + model_file = "my-model/" + serialized_file = "my-model/" + handler = "a.py::my-awesome-func" + export_path = "/Users/dummyUser/" + version = "1.0" requirements_file = "requirements.txt" + config_file = None source_vocab = None - args = Namespace(model_name=model_name, handler=handler, runtime=RuntimeType.PYTHON.value, model_file=model_file, - serialized_file=serialized_file, extra_files=None, export_path=export_path, force=False, - archive_format="default", convert=False, version=version, source_vocab=source_vocab, - requirements_file=requirements_file) + args = Namespace( + model_name=model_name, + handler=handler, + runtime=RuntimeType.PYTHON.value, + model_file=model_file, + serialized_file=serialized_file, + extra_files=None, + export_path=export_path, + force=False, + archive_format="default", + convert=False, + version=version, + source_vocab=source_vocab, + requirements_file=requirements_file, + config_file=None, + ) @pytest.fixture() def patches(self, mocker): - Patches = namedtuple('Patches', ['arg_parse', 
'export_utils', 'export_method']) - patches = Patches(mocker.patch('model_archiver.model_packaging.ArgParser'), - mocker.patch('model_archiver.model_packaging.ModelExportUtils'), - mocker.patch('model_archiver.model_packaging.package_model')) + Patches = namedtuple("Patches", ["arg_parse", "export_utils", "export_method"]) + patches = Patches( + mocker.patch("model_archiver.model_packaging.ArgParser"), + mocker.patch("model_archiver.model_packaging.ModelExportUtils"), + mocker.patch("model_archiver.model_packaging.package_model"), + ) return patches @@ -53,8 +65,11 @@ def test_gen_model_archive(self, patches): patches.export_method.assert_called() def test_export_model_method(self, patches): - patches.export_utils.check_mar_already_exists.return_value = '/Users/dummyUser/' - patches.export_utils.check_custom_model_types.return_value = '/Users/dummyUser', ['a.txt', 'b.txt'] + patches.export_utils.check_mar_already_exists.return_value = "/Users/dummyUser/" + patches.export_utils.check_custom_model_types.return_value = ( + "/Users/dummyUser", + ["a.txt", "b.txt"], + ) patches.export_utils.zip.return_value = None package_model(self.args, ModelExportUtils.generate_manifest_json(self.args)) @@ -63,8 +78,11 @@ def test_export_model_method(self, patches): def test_export_model_method_tar(self, patches): self.args.update(archive_format="tar") - patches.export_utils.check_mar_already_exists.return_value = '/Users/dummyUser/' - patches.export_utils.check_custom_model_types.return_value = '/Users/dummyUser', ['a.txt', 'b.txt'] + patches.export_utils.check_mar_already_exists.return_value = "/Users/dummyUser/" + patches.export_utils.check_custom_model_types.return_value = ( + "/Users/dummyUser", + ["a.txt", "b.txt"], + ) patches.export_utils.zip.return_value = None package_model(self.args, ModelExportUtils.generate_manifest_json(self.args)) @@ -73,8 +91,11 @@ def test_export_model_method_tar(self, patches): def test_export_model_method_noarchive(self, patches): self.args.update(archive_format="no-archive") - patches.export_utils.check_mar_already_exists.return_value = '/Users/dummyUser/' - patches.export_utils.check_custom_model_types.return_value = '/Users/dummyUser', ['a.txt', 'b.txt'] + patches.export_utils.check_mar_already_exists.return_value = "/Users/dummyUser/" + patches.export_utils.check_custom_model_types.return_value = ( + "/Users/dummyUser", + ["a.txt", "b.txt"], + ) patches.export_utils.zip.return_value = None package_model(self.args, ModelExportUtils.generate_manifest_json(self.args)) diff --git a/model-archiver/model_archiver/tests/unit_tests/test_model_packaging_utils.py b/model-archiver/model_archiver/tests/unit_tests/test_model_packaging_utils.py index 36a15fd2da..e54f4c1fba 100644 --- a/model-archiver/model_archiver/tests/unit_tests/test_model_packaging_utils.py +++ b/model-archiver/model_archiver/tests/unit_tests/test_model_packaging_utils.py @@ -1,13 +1,11 @@ - - import json import platform +from collections import namedtuple import pytest -from collections import namedtuple -from model_archiver.model_packaging_utils import ModelExportUtils from model_archiver.manifest_components.manifest import RuntimeType from model_archiver.model_archiver_error import ModelArchiverError +from model_archiver.model_packaging_utils import ModelExportUtils # noinspection PyClassHasNoInit @@ -15,58 +13,72 @@ def _validate_mar(patches): if platform.system() == "Windows": patches.path_exists.assert_called_once_with("/Users/dummyUser\\some-model.mar") else: - 
patches.path_exists.assert_called_once_with('/Users/dummyUser/some-model.mar') + patches.path_exists.assert_called_once_with("/Users/dummyUser/some-model.mar") + # noinspection PyClassHasNoInit class TestExportModelUtils: - # noinspection PyClassHasNoInit class TestMarExistence: - @pytest.fixture() def patches(self, mocker): - Patches = namedtuple('Patches', ['getcwd', 'path_exists']) - patches = Patches(mocker.patch('os.getcwd'), mocker.patch('os.path.exists')) - patches.getcwd.return_value = '/Users/dummyUser' + Patches = namedtuple("Patches", ["getcwd", "path_exists"]) + patches = Patches(mocker.patch("os.getcwd"), mocker.patch("os.path.exists")) + patches.getcwd.return_value = "/Users/dummyUser" return patches def test_export_file_is_none(self, patches): patches.path_exists.return_value = False - ret_val = ModelExportUtils.check_mar_already_exists('some-model', None, False) + ret_val = ModelExportUtils.check_mar_already_exists( + "some-model", None, False + ) _validate_mar(patches) assert ret_val == "/Users/dummyUser" def test_export_file_is_not_none(self, patches): patches.path_exists.return_value = False - ModelExportUtils.check_mar_already_exists('some-model', '/Users/dummyUser/', False) - patches.path_exists.assert_called_once_with('/Users/dummyUser/some-model.mar') + ModelExportUtils.check_mar_already_exists( + "some-model", "/Users/dummyUser/", False + ) + patches.path_exists.assert_called_once_with( + "/Users/dummyUser/some-model.mar" + ) def test_export_file_already_exists_with_override(self, patches): patches.path_exists.return_value = True - ModelExportUtils.check_mar_already_exists('some-model', None, True) + ModelExportUtils.check_mar_already_exists("some-model", None, True) _validate_mar(patches) def test_export_file_already_exists_with_override_false(self, patches): patches.path_exists.return_value = True with pytest.raises(ModelArchiverError): - ModelExportUtils.check_mar_already_exists('some-model', None, False) + ModelExportUtils.check_mar_already_exists("some-model", None, False) _validate_mar(patches) def test_export_file_is_none_tar(self, patches): patches.path_exists.return_value = False - ret_val = ModelExportUtils.check_mar_already_exists('some-model', None, False, archive_format='tgz') + ret_val = ModelExportUtils.check_mar_already_exists( + "some-model", None, False, archive_format="tgz" + ) if platform.system() == "Windows": - patches.path_exists.assert_called_once_with("/Users/dummyUser\\some-model.tar.gz") + patches.path_exists.assert_called_once_with( + "/Users/dummyUser\\some-model.tar.gz" + ) else: - patches.path_exists.assert_called_once_with("/Users/dummyUser/some-model.tar.gz") + patches.path_exists.assert_called_once_with( + "/Users/dummyUser/some-model.tar.gz" + ) assert ret_val == "/Users/dummyUser" class TestArchiveTypes: def test_archive_types(self): - from model_archiver.model_packaging_utils import archiving_options as ar_opts + from model_archiver.model_packaging_utils import ( + archiving_options as ar_opts, + ) + assert ar_opts.get("tgz") == ".tar.gz" assert ar_opts.get("no-archive") == "" assert ar_opts.get("default") == ".mar" @@ -74,52 +86,51 @@ def test_archive_types(self): # noinspection PyClassHasNoInit class TestCustomModelTypes: - - model_path = '/Users/dummyUser' + model_path = "/Users/dummyUser" @pytest.fixture() def patches(self, mocker): - Patches = namedtuple('Patches', ['utils', 'listdir']) - patch = Patches(mocker.patch('model_archiver.model_packaging_utils.ModelExportUtils'), - mocker.patch('os.listdir')) + Patches = 
namedtuple("Patches", ["utils", "listdir"]) + patch = Patches( + mocker.patch("model_archiver.model_packaging_utils.ModelExportUtils"), + mocker.patch("os.listdir"), + ) - patch.listdir.return_value = {'a', 'b', 'c'} + patch.listdir.return_value = {"a", "b", "c"} return patch # noinspection PyClassHasNoInit class TestFindUnique: - def test_with_count_zero(self): - files = ['a.txt', 'b.txt', 'c.txt'] - suffix = '.mxnet' + files = ["a.txt", "b.txt", "c.txt"] + suffix = ".mxnet" val = ModelExportUtils.find_unique(files, suffix) assert val is None def test_with_count_one(self): - files = ['a.mxnet', 'b.txt', 'c.txt'] - suffix = '.mxnet' + files = ["a.mxnet", "b.txt", "c.txt"] + suffix = ".mxnet" val = ModelExportUtils.find_unique(files, suffix) - assert val == 'a.mxnet' + assert val == "a.mxnet" def test_with_exit(self): - files = ['a.onnx', 'b.onnx', 'c.txt'] - suffix = '.onnx' + files = ["a.onnx", "b.onnx", "c.txt"] + suffix = ".onnx" with pytest.raises(ModelArchiverError): ModelExportUtils.find_unique(files, suffix) # noinspection PyClassHasNoInit class TestCleanTempFiles: - @pytest.fixture() def patches(self, mocker): - Patches = namedtuple('Patches', ['remove']) - patches = Patches(mocker.patch('os.remove')) + Patches = namedtuple("Patches", ["remove"]) + patches = Patches(mocker.patch("os.remove")) patches.remove.return_value = True return patches def test_clean_call(self, patches): - temp_files = ['a', 'b', 'c'] + temp_files = ["a", "b", "c"] ModelExportUtils.clean_temp_files(temp_files) patches.remove.assert_called() @@ -127,21 +138,27 @@ def test_clean_call(self, patches): # noinspection PyClassHasNoInit class TestGenerateManifestProps: - class Namespace: def __init__(self, **kwargs): self.__dict__.update(kwargs) - model_name = 'my-model' - handler = 'a.py::my-awesome-func' - serialized_file = 'model.pt' - model_file = 'model.pt' + model_name = "my-model" + handler = "a.py::my-awesome-func" + serialized_file = "model.pt" + model_file = "model.pt" version = "1.0" requirements_file = "requirements.txt" - args = Namespace(model_name=model_name, handler=handler, runtime=RuntimeType.PYTHON.value, - serialized_file=serialized_file, model_file=model_file, version=version, - requirements_file=requirements_file) + args = Namespace( + model_name=model_name, + handler=handler, + runtime=RuntimeType.PYTHON.value, + serialized_file=serialized_file, + model_file=model_file, + version=version, + requirements_file=requirements_file, + config_file=None, + ) def test_model(self): mod = ModelExportUtils.generate_model(self.args) @@ -151,52 +168,84 @@ def test_model(self): def test_manifest_json(self): manifest = ModelExportUtils.generate_manifest_json(self.args) manifest_json = json.loads(manifest) - assert manifest_json['runtime'] == RuntimeType.PYTHON.value - assert 'model' in manifest_json - assert 'license' not in manifest_json + assert manifest_json["runtime"] == RuntimeType.PYTHON.value + assert "model" in manifest_json + assert "license" not in manifest_json # noinspection PyClassHasNoInit class TestModelNameRegEx: - def test_regex_pass(self): - model_names = ['my-awesome-model', 'Aa.model', 'a', 'aA.model', 'a1234.model', 'a-A-A.model', '123-abc'] + model_names = [ + "my-awesome-model", + "Aa.model", + "a", + "aA.model", + "a1234.model", + "a-A-A.model", + "123-abc", + ] for m in model_names: ModelExportUtils.check_model_name_regex_or_exit(m) def test_regex_fail(self): - model_names = ['abc%', '123$abc', 'abc!123', '@123', '(model', 'mdoel)', - '12*model-a.model', '##.model', '-.model'] + 
model_names = [ + "abc%", + "123$abc", + "abc!123", + "@123", + "(model", + "mdoel)", + "12*model-a.model", + "##.model", + "-.model", + ] for m in model_names: with pytest.raises(ModelArchiverError): ModelExportUtils.check_model_name_regex_or_exit(m) # noinspection PyClassHasNoInit class TestFileFilter: - - files_to_exclude = {'abc.onnx'} + files_to_exclude = {"abc.onnx"} def test_with_return_false(self): - assert ModelExportUtils.file_filter('abc.onnx', self.files_to_exclude) is False + assert ( + ModelExportUtils.file_filter("abc.onnx", self.files_to_exclude) is False + ) def test_with_pyc(self): - assert ModelExportUtils.file_filter('abc.pyc', self.files_to_exclude) is False + assert ( + ModelExportUtils.file_filter("abc.pyc", self.files_to_exclude) is False + ) def test_with_ds_store(self): - assert ModelExportUtils.file_filter('.DS_Store', self.files_to_exclude) is False + assert ( + ModelExportUtils.file_filter(".DS_Store", self.files_to_exclude) + is False + ) def test_with_return_true(self): - assert ModelExportUtils.file_filter('abc.mxnet', self.files_to_exclude) is True + assert ( + ModelExportUtils.file_filter("abc.mxnet", self.files_to_exclude) is True + ) # noinspection PyClassHasNoInit class TestDirectoryFilter: - - unwanted_dirs = {'__MACOSX', '__pycache__'} + unwanted_dirs = {"__MACOSX", "__pycache__"} def test_with_unwanted_dirs(self): - assert ModelExportUtils.directory_filter('__MACOSX', self.unwanted_dirs) is False + assert ( + ModelExportUtils.directory_filter("__MACOSX", self.unwanted_dirs) + is False + ) def test_with_starts_with_dot(self): - assert ModelExportUtils.directory_filter('.gitignore', self.unwanted_dirs) is False + assert ( + ModelExportUtils.directory_filter(".gitignore", self.unwanted_dirs) + is False + ) def test_with_return_true(self): - assert ModelExportUtils.directory_filter('my-model', self.unwanted_dirs) is True + assert ( + ModelExportUtils.directory_filter("my-model", self.unwanted_dirs) + is True + ) diff --git a/test/pytest/test_example_near_real_time_video.py b/test/pytest/test_example_near_real_time_video.py index 093c9f97d7..18548a7e0b 100644 --- a/test/pytest/test_example_near_real_time_video.py +++ b/test/pytest/test_example_near_real_time_video.py @@ -61,10 +61,10 @@ def create_mar_file(work_dir, session_mocker, model_archiver): ).as_posix(), export_path=work_dir, requirements_file=None, - config_file=None, runtime="python", force=False, archive_format="default", + config_file=None, ) mock = session_mocker.MagicMock() diff --git a/test/pytest/test_example_scriptable_tokenzier.py b/test/pytest/test_example_scriptable_tokenzier.py index 781885c462..ca1909edc6 100644 --- a/test/pytest/test_example_scriptable_tokenzier.py +++ b/test/pytest/test_example_scriptable_tokenzier.py @@ -157,10 +157,10 @@ def create_mar_file(work_dir, session_mocker, jit_file_path, model_archiver): extra_files=os.path.join(EXAMPLE_ROOT_DIR, "index_to_name.json"), export_path=work_dir, requirements_file=None, - config_file=None, runtime="python", force=False, archive_format="default", + config_file=None, ) mock = session_mocker.MagicMock() diff --git a/test/pytest/test_example_torchrec_dlrm.py b/test/pytest/test_example_torchrec_dlrm.py index 505f4ecf9c..e4fef7e240 100644 --- a/test/pytest/test_example_torchrec_dlrm.py +++ b/test/pytest/test_example_torchrec_dlrm.py @@ -101,10 +101,10 @@ def create_mar_file(work_dir, session_mocker, serialized_file, model_archiver): + EXAMPLE_ROOT_DIR.joinpath("dlrm_model_config.py").as_posix(), export_path=work_dir, 
requirements_file=None, - config_file=None, runtime="python", force=False, archive_format="default", + config_file=None, ) mock = session_mocker.MagicMock() diff --git a/ts/tests/unit_tests/test_model_service_worker.py b/ts/tests/unit_tests/test_model_service_worker.py index b0bd4ac7cb..a17ede9650 100644 --- a/ts/tests/unit_tests/test_model_service_worker.py +++ b/ts/tests/unit_tests/test_model_service_worker.py @@ -43,7 +43,7 @@ def socket_patches(mocker): def model_service_worker(socket_patches): if not sys.platform.startswith("win"): model_service_worker = TorchModelServiceWorker( - "unix", "my-socket", None, None, metrics_config_path + "unix", "my-socket.9999", None, None, metrics_config_path ) else: model_service_worker = TorchModelServiceWorker( @@ -59,7 +59,7 @@ def model_service_worker(socket_patches): sys.platform.startswith("win"), reason="Skipping linux/darwin specific test cases" ) class TestInit: - socket_name = "sampleSocketName" + socket_name = "sampleSocketName.9999" def test_missing_socket_name(self): with pytest.raises(ValueError, match="Incomplete data provided.*"): @@ -72,7 +72,7 @@ def test_socket_in_use(self, mocker): path_exists.return_value = True with pytest.raises( - Exception, match=r".*socket already in use: sampleSocketName.*" + Exception, match=r".*socket already in use: sampleSocketName.9999.*" ): TorchModelServiceWorker( "unix", self.socket_name, None, None, metrics_config_path From b3544cd0820d0834968c171e1826b89210ef0d71 Mon Sep 17 00:00:00 2001 From: lxning Date: Thu, 23 Mar 2023 23:02:45 -0700 Subject: [PATCH 37/47] use basename --- .../model_archiver/manifest_components/model.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/model-archiver/model_archiver/manifest_components/model.py b/model-archiver/model_archiver/manifest_components/model.py index 1ff2517ba0..aed91d8b94 100644 --- a/model-archiver/model_archiver/manifest_components/model.py +++ b/model-archiver/model_archiver/manifest_components/model.py @@ -24,26 +24,15 @@ def __init__( self.model_name = model_name self.serialized_file = None if serialized_file: - if sys.platform.startswith("win32") and serialized_file.find("\\") != -1: - self.serialized_file = serialized_file.split("\\")[-1] - else: - self.serialized_file = serialized_file.split("/")[-1] + self.serialized_file = os.path.basename(serialized_file) self.model_file = model_file self.model_version = model_version self.extensions = extensions - # if sys.platform.startswith("win32") and handler.find("\\") != -1: - # self.handler = handler.split("\\")[-1] - # else: - # self.handler = handler.split("/")[-1] self.handler = os.path.basename(handler) self.requirements_file = requirements_file self.config_file = None if config_file: - if sys.platform.startswith("win32") and config_file.find("\\") != -1: - self.config_file = config_file.split("\\")[-1] - else: - self.config_file = config_file.split("/")[-1] - + self.config_file = os.path.basename(config_file) self.model_dict = self.__to_dict__() def __to_dict__(self): From 35857b71b3e9ee4883093c8db07922c0089988f5 Mon Sep 17 00:00:00 2001 From: lxning Date: Thu, 23 Mar 2023 23:05:18 -0700 Subject: [PATCH 38/47] precommit fmt --- model-archiver/model_archiver/manifest_components/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/model-archiver/model_archiver/manifest_components/model.py b/model-archiver/model_archiver/manifest_components/model.py index aed91d8b94..e250ad61bd 100644 --- a/model-archiver/model_archiver/manifest_components/model.py +++ 
b/model-archiver/model_archiver/manifest_components/model.py @@ -1,7 +1,6 @@ # pylint: disable=missing-docstring import json import os -import sys class Model(object): From 94f03f4ce6db7a04d90e2580d3c4f6c32aeb33b0 Mon Sep 17 00:00:00 2001 From: lxning Date: Fri, 24 Mar 2023 22:26:00 -0700 Subject: [PATCH 39/47] check model config input --- .../serve/archive/model/ModelConfig.java | 29 +++++++++++++++++++ .../java/org/pytorch/serve/ModelServer.java | 10 +++++++ .../java/org/pytorch/serve/wlm/Model.java | 7 +++++ 3 files changed, 46 insertions(+) diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java index 87ec11fd02..dc859fbb87 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java @@ -6,8 +6,12 @@ import java.util.Optional; import java.util.stream.Collectors; import java.util.stream.Stream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class ModelConfig { + private static final Logger logger = LoggerFactory.getLogger(ModelConfig.class); + private int minWorkers; private int maxWorkers; private int batchSize; @@ -62,6 +66,10 @@ public int getMinWorkers() { } public void setMinWorkers(int minWorkers) { + if (minWorkers < 0) { + logger.warn("Invalid minWorkers:{}", minWorkers); + return; + } this.minWorkers = minWorkers; } @@ -70,6 +78,10 @@ public int getMaxWorkers() { } public void setMaxWorkers(int maxWorkers) { + if (maxWorkers < 0) { + logger.warn("Invalid maxWorkers:{}", maxWorkers); + return; + } this.maxWorkers = maxWorkers; } @@ -78,6 +90,10 @@ public int getBatchSize() { } public void setBatchSize(int batchSize) { + if (batchSize <= 0) { + logger.warn("Invalid batchSize:{}", batchSize); + return; + } this.batchSize = batchSize; } @@ -86,6 +102,10 @@ public int getMaxBatchDelay() { } public void setMaxBatchDelay(int maxBatchDelay) { + if (maxBatchDelay < 0) { + logger.warn("Invalid maxBatchDelay:{}", maxBatchDelay); + return; + } this.maxBatchDelay = maxBatchDelay; } @@ -94,6 +114,10 @@ public int getResponseTimeout() { } public void setResponseTimeout(int responseTimeout) { + if (responseTimeout <= 0) { + logger.warn("Invalid responseTimeout:{}", responseTimeout); + return; + } this.responseTimeout = responseTimeout; } @@ -114,6 +138,11 @@ public int getParallelLevel() { } public void setParallelLevel(int parallelLevel) { + if (parallelLevel <= 0) { + logger.warn("Invalid parallelLevel:{}, set as 1", parallelLevel); + this.parallelLevel = 1; + return; + } this.parallelLevel = parallelLevel; } diff --git a/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java b/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java index b835c5e211..e2d9ec7031 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java +++ b/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java @@ -212,6 +212,11 @@ private void initModelStore() throws InvalidSnapshotException, IOException { if (marMinWorkers > 0 && marMaxWorkers >= marMinWorkers) { minWorkers = marMinWorkers; maxWorkers = marMaxWorkers; + } else { + logger.warn( + "Invalid model config in mar, minWorkers:{}, maxWorkers:{}", + marMinWorkers, + marMaxWorkers); } } modelManager.updateModel( @@ -284,6 +289,11 @@ private void initModelStore() throws InvalidSnapshotException, IOException { if (marMinWorkers > 0 && marMaxWorkers >= marMinWorkers) { 
minWorkers = marMinWorkers; maxWorkers = marMaxWorkers; + } else { + logger.warn( + "Invalid model config in mar, minWorkers:{}, maxWorkers:{}", + marMinWorkers, + marMaxWorkers); } } modelManager.updateModel( diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java index cacfa09c2b..015a8055b1 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java @@ -78,6 +78,13 @@ public Model(ModelArchive modelArchive, int queueSize) { : deviceType; } deviceIds = modelArchive.getModelConfig().getDeviceIds(); + for (Integer deviceId : deviceIds) { + if (deviceId < 0 || deviceId >= ConfigManager.getInstance().getNumberOfGpu()) { + logger.warn("Invalid deviceId:{}, ignore deviceIds list", deviceId); + deviceIds = null; + break; + } + } } else { batchSize = 1; maxBatchDelay = 100; From de1efb34ef2730e3b2aaabcc3841c350d1b417bf Mon Sep 17 00:00:00 2001 From: lxning Date: Sat, 25 Mar 2023 14:01:36 -0700 Subject: [PATCH 40/47] add Model ConfigTest --- .../serve/archive/model/ModelConfig.java | 120 +++++++++++++----- .../serve/archive/model/ModelConfigTest.java | 48 +++++++ .../test/resources/modelConfig/invalid.yaml | 10 ++ .../src/test/resources/modelConfig/valid.yaml | 10 ++ frontend/archive/testng.xml | 1 + 5 files changed, 157 insertions(+), 32 deletions(-) create mode 100644 frontend/archive/src/test/java/org/pytorch/serve/archive/model/ModelConfigTest.java create mode 100644 frontend/archive/src/test/resources/modelConfig/invalid.yaml create mode 100644 frontend/archive/src/test/resources/modelConfig/valid.yaml diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java index dc859fbb87..e82873db64 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java @@ -1,11 +1,10 @@ package org.pytorch.serve.archive.model; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; -import java.util.stream.Stream; +import java.util.NoSuchElementException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -28,31 +27,68 @@ public static ModelConfig build(Map yamlMap) { (k, v) -> { switch (k) { case "minWorkers": - modelConfig.setMinWorkers((int) v); + if (v instanceof Integer) { + modelConfig.setMinWorkers((int) v); + } else { + logger.warn("Invalid minWorkers: {}, should be integer", v); + } break; case "maxWorkers": - modelConfig.setMaxWorkers((int) v); + if (v instanceof Integer) { + modelConfig.setMaxWorkers((int) v); + } else { + logger.warn("Invalid maxWorkers: {}, should be integer", v); + } break; case "batchSize": - modelConfig.setBatchSize((int) v); + if (v instanceof Integer) { + modelConfig.setBatchSize((int) v); + } else { + logger.warn("Invalid batchSize: {}, should be integer", v); + } break; case "maxBatchDelay": - modelConfig.setMaxBatchDelay((int) v); + if (v instanceof Integer) { + modelConfig.setMaxBatchDelay((int) v); + } else { + logger.warn("Invalid maxBatchDelay: {}, should be integer", v); + } break; case "responseTimeout": - modelConfig.setResponseTimeout((int) v); + if (v instanceof Integer) { + modelConfig.setResponseTimeout((int) v); + } else { + logger.warn("Invalid 
responseTimeout: {}, should be integer", v); + } break; case "deviceType": - modelConfig.setDeviceType((String) v); + if (v instanceof String) { + modelConfig.setDeviceType((String) v); + } else { + logger.warn("Invalid deviceType: {}, should be cpu, or gpu", v); + } break; case "parallelLevel": - modelConfig.setParallelLevel((int) v); + if (v instanceof Integer) { + modelConfig.setParallelLevel((int) v); + } else { + logger.warn("Invalid parallelLevel: {}, should be integer >= 1", v); + } break; case "parallelType": - modelConfig.setParallelMode((String) v); + if (v instanceof String) { + modelConfig.setParallelMode((String) v); + } else { + logger.warn( + "Invalid parallelType: {}, should be pp, tp,or pptp", v); + } break; case "deviceIds": - modelConfig.setDeviceIds(v); + if (v instanceof List) { + modelConfig.setDeviceIds((List) v); + } else { + logger.warn("Invalid deviceIds: {}, should be list of integer", v); + } break; default: break; @@ -125,12 +161,17 @@ public List getDeviceIds() { return deviceIds; } - public void setDeviceIds(Object deviceIds) { - this.deviceIds = - Stream.of(deviceIds) - .map(Object::toString) - .map(Integer::parseInt) - .collect(Collectors.toList()); + public void setDeviceIds(List deviceIds) { + this.deviceIds = new ArrayList<>(); + for (int i = 0; i < deviceIds.size(); i++) { + if (deviceIds.get(i) instanceof Integer) { + this.deviceIds.add((int) deviceIds.get(i)); + } else { + logger.warn("Invalid deviceIds:{},", deviceIds.get(i)); + this.deviceIds = null; + break; + } + } } public int getParallelLevel() { @@ -147,7 +188,7 @@ public void setParallelLevel(int parallelLevel) { } public void setParallelMode(String parallelMode) { - this.parallelType = ParallelType.get(parallelMode).get(); + this.parallelType = ParallelType.get(parallelMode); } public ParallelType getParallelType() { @@ -155,7 +196,7 @@ public ParallelType getParallelType() { } public void setDeviceType(String deviceType) { - this.deviceType = DeviceType.get(deviceType).get(); + this.deviceType = DeviceType.get(deviceType); } public DeviceType getDeviceType() { @@ -171,40 +212,55 @@ public enum ParallelType { private String type; ParallelType(String type) { - this.type = type; + this.type = type.toLowerCase(); } public String getParallelType() { return type; } - public static Optional get(String parallelType) { - return Arrays.stream(ParallelType.values()) - .filter(t -> t.type.equals(parallelType)) - .findFirst(); + public static ParallelType get(String parallelType) { + ParallelType pType = NONE; + try { + pType = + Arrays.stream(ParallelType.values()) + .filter(t -> t.type.equals(parallelType.toLowerCase())) + .findFirst() + .get(); + } catch (NoSuchElementException e) { + logger.warn("Invalid ParallelType:{}", parallelType, e); + } + return pType; } } public enum DeviceType { NONE(""), CPU("cpu"), - GPU("gpu"), - NEURON("neuron"); + GPU("gpu"); private String type; DeviceType(String type) { - this.type = type; + this.type = type.toLowerCase(); } public String getDeviceType() { return type; } - public static Optional get(String deviceType) { - return Arrays.stream(DeviceType.values()) - .filter(t -> t.type.equals(deviceType)) - .findFirst(); + public static DeviceType get(String deviceType) { + DeviceType dType = DeviceType.NONE; + try { + dType = + Arrays.stream(DeviceType.values()) + .filter(t -> t.type.equals(deviceType.toLowerCase())) + .findFirst() + .get(); + } catch (NoSuchElementException e) { + logger.warn("Invalid DeviceType:{}", deviceType, e); + } + return dType; } } } diff 
--git a/frontend/archive/src/test/java/org/pytorch/serve/archive/model/ModelConfigTest.java b/frontend/archive/src/test/java/org/pytorch/serve/archive/model/ModelConfigTest.java new file mode 100644 index 0000000000..4f5e8805f1 --- /dev/null +++ b/frontend/archive/src/test/java/org/pytorch/serve/archive/model/ModelConfigTest.java @@ -0,0 +1,48 @@ +package org.pytorch.serve.archive.model; + +import java.io.File; +import java.io.IOException; +import java.util.Map; +import org.pytorch.serve.archive.utils.ArchiveUtils; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class ModelConfigTest { + @Test + public void TestValidYamlConfig() throws InvalidModelException, IOException { + String yamlConfigFile = "src/test/resources/modelConfig/valid.yaml"; + ModelConfig modelConfig; + File configFile = new File(yamlConfigFile); + Map modelConfigMap = ArchiveUtils.readYamlFile(configFile); + modelConfig = ModelConfig.build(modelConfigMap); + + Assert.assertEquals(modelConfig.getMinWorkers(), 1); + Assert.assertEquals(modelConfig.getMaxWorkers(), 1); + Assert.assertEquals(modelConfig.getBatchSize(), 1); + Assert.assertEquals(modelConfig.getMaxBatchDelay(), 100); + Assert.assertEquals(modelConfig.getResponseTimeout(), 120); + Assert.assertEquals(modelConfig.getDeviceType(), ModelConfig.DeviceType.GPU); + Assert.assertEquals(modelConfig.getParallelLevel(), 4); + Assert.assertEquals(modelConfig.getParallelType(), ModelConfig.ParallelType.PP); + Assert.assertEquals(modelConfig.getDeviceIds().get(0).intValue(), 0); + } + + @Test + public void TestInvalidYamlConfig() throws InvalidModelException, IOException { + String yamlConfigFile = "src/test/resources/modelConfig/invalid.yaml"; + ModelConfig modelConfig; + File configFile = new File(yamlConfigFile); + Map modelConfigMap = ArchiveUtils.readYamlFile(configFile); + modelConfig = ModelConfig.build(modelConfigMap); + + Assert.assertNotEquals(modelConfig.getMinWorkers(), 1); + Assert.assertEquals(modelConfig.getMaxWorkers(), 1); + Assert.assertEquals(modelConfig.getBatchSize(), 1); + Assert.assertEquals(modelConfig.getMaxBatchDelay(), 100); + Assert.assertEquals(modelConfig.getResponseTimeout(), 120); + Assert.assertNotEquals(modelConfig.getDeviceType(), ModelConfig.DeviceType.GPU); + Assert.assertEquals(modelConfig.getParallelLevel(), 4); + Assert.assertNotEquals(modelConfig.getParallelType(), ModelConfig.ParallelType.PPTP); + Assert.assertNull(modelConfig.getDeviceIds()); + } +} diff --git a/frontend/archive/src/test/resources/modelConfig/invalid.yaml b/frontend/archive/src/test/resources/modelConfig/invalid.yaml new file mode 100644 index 0000000000..3544ef5ada --- /dev/null +++ b/frontend/archive/src/test/resources/modelConfig/invalid.yaml @@ -0,0 +1,10 @@ +# TS Frontend parameters +minWorkers: a +maxWorkers: 1 +batchSize: 1 +maxBatchDelay: 100 +responseTimeout: 120 +deviceType: xpu # cpu, gpu +deviceIds: 0,1,2,3] # device index for gpu +parallelLevel: 4 # rpc world size +parallelType: "xpp" # pp: pipeline parallel; pptp: tensor+pipeline parallel \ No newline at end of file diff --git a/frontend/archive/src/test/resources/modelConfig/valid.yaml b/frontend/archive/src/test/resources/modelConfig/valid.yaml new file mode 100644 index 0000000000..8980c1b6a4 --- /dev/null +++ b/frontend/archive/src/test/resources/modelConfig/valid.yaml @@ -0,0 +1,10 @@ +# TS Frontend parameters +minWorkers: 1 +maxWorkers: 1 +batchSize: 1 +maxBatchDelay: 100 +responseTimeout: 120 +deviceType: "gpu" # cpu, gpu +deviceIds: [0,1,2,3] # device index for gpu, 
neuron +parallelLevel: 4 # rpc world size +parallelType: "pp" # pp: pipeline parallel; pptp: tensor+pipeline parallel \ No newline at end of file diff --git a/frontend/archive/testng.xml b/frontend/archive/testng.xml index 16540e207d..0d050dfbcd 100644 --- a/frontend/archive/testng.xml +++ b/frontend/archive/testng.xml @@ -5,6 +5,7 @@ + From 7a60dd33810467d5c6ad0d18ac11c3803abe9974 Mon Sep 17 00:00:00 2001 From: lxning Date: Sat, 25 Mar 2023 14:08:42 -0700 Subject: [PATCH 41/47] check deviceIds --- .../src/main/java/org/pytorch/serve/wlm/Model.java | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java index 015a8055b1..9d721e5d8a 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java @@ -78,11 +78,13 @@ public Model(ModelArchive modelArchive, int queueSize) { : deviceType; } deviceIds = modelArchive.getModelConfig().getDeviceIds(); - for (Integer deviceId : deviceIds) { - if (deviceId < 0 || deviceId >= ConfigManager.getInstance().getNumberOfGpu()) { - logger.warn("Invalid deviceId:{}, ignore deviceIds list", deviceId); - deviceIds = null; - break; + if (deviceIds != null) { + for (Integer deviceId : deviceIds) { + if (deviceId < 0 || deviceId >= ConfigManager.getInstance().getNumberOfGpu()) { + logger.warn("Invalid deviceId:{}, ignore deviceIds list", deviceId); + deviceIds = null; + break; + } } } } else { From 615e91a84226b2aa866240e1f8692b7942fdba98 Mon Sep 17 00:00:00 2001 From: lxning Date: Sat, 25 Mar 2023 16:43:19 -0700 Subject: [PATCH 42/47] set gpu >=1000 as invlid config --- .../serve/archive/model/ModelConfig.java | 12 ++++++++++ .../serve/archive/model/ModelConfigTest.java | 2 +- .../java/org/pytorch/serve/wlm/Model.java | 24 +++++++++++++------ .../pytorch/serve/wlm/WorkLoadManager.java | 3 +++ ts/service.py | 9 +++++-- 5 files changed, 40 insertions(+), 10 deletions(-) diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java index e82873db64..f086fe91b7 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java @@ -20,9 +20,11 @@ public class ModelConfig { private List deviceIds; private int parallelLevel = 1; private ParallelType parallelType = ParallelType.NONE; + private boolean deviceIdsValid; public static ModelConfig build(Map yamlMap) { ModelConfig modelConfig = new ModelConfig(); + modelConfig.deviceIdsValid = true; yamlMap.forEach( (k, v) -> { switch (k) { @@ -87,6 +89,7 @@ public static ModelConfig build(Map yamlMap) { if (v instanceof List) { modelConfig.setDeviceIds((List) v); } else { + modelConfig.deviceIdsValid = false; logger.warn("Invalid deviceIds: {}, should be list of integer", v); } break; @@ -169,6 +172,7 @@ public void setDeviceIds(List deviceIds) { } else { logger.warn("Invalid deviceIds:{},", deviceIds.get(i)); this.deviceIds = null; + this.deviceIdsValid = false; break; } } @@ -203,6 +207,14 @@ public DeviceType getDeviceType() { return deviceType; } + public boolean isDeviceIdsValid() { + return deviceIdsValid; + } + + public void setDeviceIdsValid(boolean deviceIdsValid) { + this.deviceIdsValid = deviceIdsValid; + } + public enum ParallelType { NONE(""), PP("pp"), diff --git 
a/frontend/archive/src/test/java/org/pytorch/serve/archive/model/ModelConfigTest.java b/frontend/archive/src/test/java/org/pytorch/serve/archive/model/ModelConfigTest.java index 4f5e8805f1..db986aab73 100644 --- a/frontend/archive/src/test/java/org/pytorch/serve/archive/model/ModelConfigTest.java +++ b/frontend/archive/src/test/java/org/pytorch/serve/archive/model/ModelConfigTest.java @@ -24,7 +24,7 @@ public void TestValidYamlConfig() throws InvalidModelException, IOException { Assert.assertEquals(modelConfig.getDeviceType(), ModelConfig.DeviceType.GPU); Assert.assertEquals(modelConfig.getParallelLevel(), 4); Assert.assertEquals(modelConfig.getParallelType(), ModelConfig.ParallelType.PP); - Assert.assertEquals(modelConfig.getDeviceIds().get(0).intValue(), 0); + Assert.assertEquals(modelConfig.getDeviceIds().get(2).intValue(), 2); } @Test diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java index 9d721e5d8a..e30268d22a 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java @@ -52,6 +52,7 @@ public class Model { private int responseTimeout; private ModelVersionName modelVersionName; private AtomicInteger gpuCounter = new AtomicInteger(0); + private boolean isDeviceIdsValid = true; private boolean isWorkflowModel; @@ -77,13 +78,18 @@ public Model(ModelArchive modelArchive, int queueSize) { ? ModelConfig.DeviceType.GPU : deviceType; } - deviceIds = modelArchive.getModelConfig().getDeviceIds(); - if (deviceIds != null) { - for (Integer deviceId : deviceIds) { - if (deviceId < 0 || deviceId >= ConfigManager.getInstance().getNumberOfGpu()) { - logger.warn("Invalid deviceId:{}, ignore deviceIds list", deviceId); - deviceIds = null; - break; + isDeviceIdsValid = modelArchive.getModelConfig().isDeviceIdsValid(); + if (isDeviceIdsValid) { + deviceIds = modelArchive.getModelConfig().getDeviceIds(); + if (deviceIds != null) { + for (Integer deviceId : deviceIds) { + if (deviceId < 0 + || deviceId >= ConfigManager.getInstance().getNumberOfGpu()) { + logger.warn("Invalid deviceId:{}, ignore deviceIds list", deviceId); + deviceIds = null; + isDeviceIdsValid = false; + break; + } } } } @@ -334,4 +340,8 @@ public int getNumCores() { public AtomicInteger getGpuCounter() { return gpuCounter; } + + public boolean isDeviceIdsValid() { + return isDeviceIdsValid; + } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java index 27ea12f249..0344657e9f 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java @@ -208,6 +208,9 @@ private void addThreads( gpuCounter.accumulateAndGet( maxGpu, (prev, maxGpuId) -> ++prev % maxGpuId); } + if (!model.isDeviceIdsValid()) { + gpuId += 1000; + } } BatchAggregator aggregator = new BatchAggregator(model); diff --git a/ts/service.py b/ts/service.py index e95785ed7a..6796ee32ff 100644 --- a/ts/service.py +++ b/ts/service.py @@ -41,8 +41,13 @@ def __init__( ) if "deviceIds" in model_yaml_config and "parallelLevel" in model_yaml_config: - if int(model_yaml_config["parallelLevel"]) == 1: - gpu = model_yaml_config["deviceIds"][gpu] + if type(model_yaml_config["parallelLevel"]) is not int or \ + int(model_yaml_config["parallelLevel"]) <= 1: + # devicedIds is invalid in model config yaml file + if 
gpu >= 1000: + gpu = gpu % 1000 + else: + gpu = int(model_yaml_config["deviceIds"][gpu]) self._context = Context( model_name, From 8e7acad99e2120204e8e035f94d64d9634eb443f Mon Sep 17 00:00:00 2001 From: lxning Date: Sat, 25 Mar 2023 17:03:23 -0700 Subject: [PATCH 43/47] precommit fmt --- ts/service.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ts/service.py b/ts/service.py index 6796ee32ff..cbd03a086b 100644 --- a/ts/service.py +++ b/ts/service.py @@ -41,8 +41,10 @@ def __init__( ) if "deviceIds" in model_yaml_config and "parallelLevel" in model_yaml_config: - if type(model_yaml_config["parallelLevel"]) is not int or \ - int(model_yaml_config["parallelLevel"]) <= 1: + if ( + type(model_yaml_config["parallelLevel"]) is not int + or int(model_yaml_config["parallelLevel"]) <= 1 + ): # devicedIds is invalid in model config yaml file if gpu >= 1000: gpu = gpu % 1000 From 83605dd339c6fb5a7d3844456cab75a7bbe26c87 Mon Sep 17 00:00:00 2001 From: lxning Date: Sat, 25 Mar 2023 21:19:33 -0700 Subject: [PATCH 44/47] optimize gpu assignment --- .../serve/archive/model/ModelConfig.java | 12 -------- .../java/org/pytorch/serve/wlm/Model.java | 28 +++++++++---------- .../pytorch/serve/wlm/WorkLoadManager.java | 17 ++++++----- ts/service.py | 26 ++++++++++------- 4 files changed, 37 insertions(+), 46 deletions(-) diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java index f086fe91b7..e82873db64 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java @@ -20,11 +20,9 @@ public class ModelConfig { private List deviceIds; private int parallelLevel = 1; private ParallelType parallelType = ParallelType.NONE; - private boolean deviceIdsValid; public static ModelConfig build(Map yamlMap) { ModelConfig modelConfig = new ModelConfig(); - modelConfig.deviceIdsValid = true; yamlMap.forEach( (k, v) -> { switch (k) { @@ -89,7 +87,6 @@ public static ModelConfig build(Map yamlMap) { if (v instanceof List) { modelConfig.setDeviceIds((List) v); } else { - modelConfig.deviceIdsValid = false; logger.warn("Invalid deviceIds: {}, should be list of integer", v); } break; @@ -172,7 +169,6 @@ public void setDeviceIds(List deviceIds) { } else { logger.warn("Invalid deviceIds:{},", deviceIds.get(i)); this.deviceIds = null; - this.deviceIdsValid = false; break; } } @@ -207,14 +203,6 @@ public DeviceType getDeviceType() { return deviceType; } - public boolean isDeviceIdsValid() { - return deviceIdsValid; - } - - public void setDeviceIdsValid(boolean deviceIdsValid) { - this.deviceIdsValid = deviceIdsValid; - } - public enum ParallelType { NONE(""), PP("pp"), diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java index e30268d22a..2f3247df3b 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java @@ -52,7 +52,7 @@ public class Model { private int responseTimeout; private ModelVersionName modelVersionName; private AtomicInteger gpuCounter = new AtomicInteger(0); - private boolean isDeviceIdsValid = true; + private boolean hasDeviceIds; private boolean isWorkflowModel; @@ -78,18 +78,16 @@ public Model(ModelArchive modelArchive, int queueSize) { ? 
ModelConfig.DeviceType.GPU : deviceType; } - isDeviceIdsValid = modelArchive.getModelConfig().isDeviceIdsValid(); - if (isDeviceIdsValid) { - deviceIds = modelArchive.getModelConfig().getDeviceIds(); - if (deviceIds != null) { - for (Integer deviceId : deviceIds) { - if (deviceId < 0 - || deviceId >= ConfigManager.getInstance().getNumberOfGpu()) { - logger.warn("Invalid deviceId:{}, ignore deviceIds list", deviceId); - deviceIds = null; - isDeviceIdsValid = false; - break; - } + + deviceIds = modelArchive.getModelConfig().getDeviceIds(); + if (deviceIds != null && deviceIds.size() > 0) { + hasDeviceIds = true; + for (Integer deviceId : deviceIds) { + if (deviceId < 0 || deviceId >= ConfigManager.getInstance().getNumberOfGpu()) { + logger.warn("Invalid deviceId:{}, ignore deviceIds list", deviceId); + deviceIds = null; + hasDeviceIds = false; + break; } } } @@ -341,7 +339,7 @@ public AtomicInteger getGpuCounter() { return gpuCounter; } - public boolean isDeviceIdsValid() { - return isDeviceIdsValid; + public boolean isHasDeviceIds() { + return hasDeviceIds; } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java index 0344657e9f..15a762ea05 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java @@ -196,21 +196,20 @@ private void addThreads( int gpuId = -1; if (maxGpu > 0) { - if (model.getDeviceIds() != null && model.getDeviceIds().size() > 0) { + if (model.isHasDeviceIds()) { gpuId = - model.getGpuCounter() - .getAndAccumulate( - maxGpu, - (prev, maxGpuId) -> - (prev + model.getParallelLevel()) % maxGpuId); + 1000 + + model.getGpuCounter() + .getAndAccumulate( + maxGpu, + (prev, maxGpuId) -> + (prev + model.getParallelLevel()) + % maxGpuId); } else { gpuId = gpuCounter.accumulateAndGet( maxGpu, (prev, maxGpuId) -> ++prev % maxGpuId); } - if (!model.isDeviceIdsValid()) { - gpuId += 1000; - } } BatchAggregator aggregator = new BatchAggregator(model); diff --git a/ts/service.py b/ts/service.py index cbd03a086b..a5214fb9f2 100644 --- a/ts/service.py +++ b/ts/service.py @@ -40,16 +40,22 @@ def __init__( os.path.join(model_dir, model_yaml_config_file) ) - if "deviceIds" in model_yaml_config and "parallelLevel" in model_yaml_config: - if ( - type(model_yaml_config["parallelLevel"]) is not int - or int(model_yaml_config["parallelLevel"]) <= 1 - ): - # devicedIds is invalid in model config yaml file - if gpu >= 1000: - gpu = gpu % 1000 - else: - gpu = int(model_yaml_config["deviceIds"][gpu]) + parallelLevel = 1 + if ( + "parallelLevel" in model_yaml_config + and type(model_yaml_config["parallelLevel"]) is int + and int(model_yaml_config["parallelLevel"]) > 1 + ): + parallelLevel = int(model_yaml_config["parallelLevel"]) + + # devicedIds in model config yaml file + if type(gpu) is int and gpu >= 1000: + if parallelLevel == 1: + gpu = int(model_yaml_config["deviceIds"][gpu % 1000]) + else: + gpu = gpu % 1000 + elif "deviceIds" in model_yaml_config: + del model_yaml_config["deviceIds"] self._context = Context( model_name, From 8e4a1c562e2bb832d42c3ac6f5fd86ca63a72c36 Mon Sep 17 00:00:00 2001 From: lxning Date: Tue, 28 Mar 2023 18:07:11 -0700 Subject: [PATCH 45/47] precommit fmt --- .../tests/unit_tests/test_model_packaging_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/model-archiver/model_archiver/tests/unit_tests/test_model_packaging_utils.py 
b/model-archiver/model_archiver/tests/unit_tests/test_model_packaging_utils.py index 645916da2b..802eedfa6d 100644 --- a/model-archiver/model_archiver/tests/unit_tests/test_model_packaging_utils.py +++ b/model-archiver/model_archiver/tests/unit_tests/test_model_packaging_utils.py @@ -11,6 +11,7 @@ MANIFEST_FILE = Path(__file__).parents[1].joinpath("integ_tests/MAR-INF/MANIFEST.json") + # noinspection PyClassHasNoInit def _validate_mar(patches): if platform.system() == "Windows": @@ -254,6 +255,7 @@ def test_with_return_true(self): is True ) + def create_manifest_from_test_json(test_json): test_ = {k.replace("-", "_"): v for k, v in test_json.items()} test_["requirements_file"] = "" From 628f58bfe6d9def240c3bcf15aa60980f1412f47 Mon Sep 17 00:00:00 2001 From: lxning Date: Tue, 28 Mar 2023 20:21:32 -0700 Subject: [PATCH 46/47] fix sanity test error caused by auto merge --- .../tests/unit_tests/test_model_packaging_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/model-archiver/model_archiver/tests/unit_tests/test_model_packaging_utils.py b/model-archiver/model_archiver/tests/unit_tests/test_model_packaging_utils.py index 802eedfa6d..19f7e447ea 100644 --- a/model-archiver/model_archiver/tests/unit_tests/test_model_packaging_utils.py +++ b/model-archiver/model_archiver/tests/unit_tests/test_model_packaging_utils.py @@ -260,6 +260,7 @@ def create_manifest_from_test_json(test_json): test_ = {k.replace("-", "_"): v for k, v in test_json.items()} test_["requirements_file"] = "" test_["runtime"] = RuntimeType.PYTHON3.value + test_["config_file"] = "" args = namedtuple("Model", test_.keys())(**test_) manifest = ModelExportUtils.generate_manifest_json(args) @@ -281,6 +282,7 @@ def test_archive_creation_with_zip_store(tmp_path, integ_tests): "serialized-file", "handler", "extra-files", + "config-file", ) for k in keys: From 08b6eabd82f239806e57138a1075763e68723ec9 Mon Sep 17 00:00:00 2001 From: lxning Date: Tue, 28 Mar 2023 21:37:38 -0700 Subject: [PATCH 47/47] fix test_archive_creation_with_zip_store --- .../model_archiver/tests/integ_tests/configuration.json | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/model-archiver/model_archiver/tests/integ_tests/configuration.json b/model-archiver/model_archiver/tests/integ_tests/configuration.json index a809a49c52..528773fb99 100644 --- a/model-archiver/model_archiver/tests/integ_tests/configuration.json +++ b/model-archiver/model_archiver/tests/integ_tests/configuration.json @@ -71,7 +71,8 @@ "archive-format": "zip-store", "iterations": 2, "version": "1.0", - "force": true + "force": true, + "config-file": "" }, { "name": "packaging_mar_with_handler_name",