From c98c55a5873feed6c99ff7f3773e507288f6cb5b Mon Sep 17 00:00:00 2001
From: jiong-zhang
Date: Fri, 5 Aug 2022 02:16:00 +0000
Subject: [PATCH] Add README to distributed XR-Transformer

---
 pecos/distributed/xmc/xtransformer/README.md  | 112 +++++++
 pecos/distributed/xmc/xtransformer/model.py   |  15 +-
 ...ormer cookbook and Distributed PECOS.ipynb | 310 ++++++++++--------
 3 files changed, 295 insertions(+), 142 deletions(-)
 create mode 100644 pecos/distributed/xmc/xtransformer/README.md

diff --git a/pecos/distributed/xmc/xtransformer/README.md b/pecos/distributed/xmc/xtransformer/README.md
new file mode 100644
index 0000000..b28e4ba
--- /dev/null
+++ b/pecos/distributed/xmc/xtransformer/README.md
@@ -0,0 +1,112 @@
+# Distributed PECOS eXtreme Multi-label Classification: XR-Transformer
+
+`pecos.distributed.xmc.xtransformer` enables distributed fine-tuning for the PECOS XR-Transformer model ([`pecos.xmc.xtransformer`](../../../xmc/xtransformer/README.md)).
+
+Note that this module only supports fine-tuning of XR-Transformer encoders; it does not cover hierarchical label tree building or linear ranker training.
+
+## Prerequisites
+
+### Hardware
+
+You need the following hardware to train distributed PECOS:
+
+* A cluster of machines, connected over a network, that can SSH to each other without passwords.
+  * The IP address of every machine in the cluster is known.
+* A shared network disk mounted on all machines.
+  * Used to access training data and save trained models.
+
+We do not currently provide a guide for setting up such a cluster, but we may in the future. For now, please refer to your organization's hardware management for help.
+
+
+### Software
+
+Install the following software on **every** machine of your cluster:
+
+#### Install PECOS
+Please follow the [main guide for PECOS installation](https://github.com/amzn/pecos#requirements-and-installation).
+
+#### Install DeepSpeed
+
+```bash
+DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 pip3 install deepspeed
+```
+
+### Workspace Setup
+Make a workspace directory on your shared network disk:
+
+```bash
+cd <your-shared-network-disk>
+mkdir pecos-workspace && cd pecos-workspace
+```
+Create a `hostfile` listing, one machine per line, each cluster machine's IP address and its number of GPU slots:
+```bash
+cat << EOF > hostfile
+<machine1-ip> slots=<num-gpus>
+<machine2-ip> slots=<num-gpus>
+...
+<machineN-ip> slots=<num-gpus>
+EOF
+```
+Test cluster connectivity:
+```bash
+deepspeed --hostfile hostfile --module pecos.distributed.diagnostic_tools.deepspeed_comm --timeout 60 --shared-workdir .
+```
+
+## Getting started
+
+### Basic Command-line Usage
+
+The distributed training CLI `pecos.distributed.xmc.xtransformer.train` is similar to the single-machine `pecos.xmc.xtransformer.train`.
+
+There are several additional things to note:
+
+* **Have the Hierarchical Label Tree (HLT) ready**: the distributed training module will not construct the HLT for you, so you need to supply an existing label clustering with `--code-path`; otherwise the module falls back to One-Versus-All fine-tuning, which is not feasible for large label spaces. A minimal sketch of building such a clustering is shown below.
+* **Instance numerical features are not accepted**: training of the sparse+dense concatenated linear models is disabled.
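+
+If you do not already have a label clustering for `--code-path`, the following is a minimal sketch of building one with the single-machine PECOS utilities. The file names `X.trn.npz`, `Y.trn.npz`, and `C.npz` are illustrative placeholders, and default hierarchical k-means parameters are assumed:
+
+```python
+import numpy as np
+from pecos.utils import smat_util
+from pecos.xmc import Indexer, LabelEmbeddingFactory
+
+# illustrative input paths: (N, d) instance feature matrix and (N, L) label matrix
+X_feat = smat_util.load_matrix("X.trn.npz", dtype=np.float32)
+Y = smat_util.load_matrix("Y.trn.npz", dtype=np.float32)
+
+# PIFA label embeddings followed by hierarchical k-means clustering
+label_feat = LabelEmbeddingFactory.create(Y, X_feat, method="pifa")
+cluster_chain = Indexer.gen(label_feat)
+
+# save the leaf-level clustering matrix with shape (L, K) for --code-path
+smat_util.save_matrix("C.npz", cluster_chain[-1])
+```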
+
+You can generate a `.json` file with all of the parameters that you can edit and fill in:
+```bash
+python3 -m pecos.distributed.xmc.xtransformer.train --generate-params-skeleton &> params.json
+```
+
+After editing the `params.json` file, you can start training via:
+
+```bash
+python3 -m pecos.distributed.xmc.xtransformer.train \
+    --trn-text-path ${T_path} \
+    --trn-label-path ${Y_path} \
+    --code-path ${C_path} \
+    --model-dir ${model_dir} \
+    --params-path params.json
+```
+where
+* `T_path` is the path to the input text file of the training instances: a text file with `N` lines, where each line is the text feature of the corresponding training instance.
+* `Y_path` is the path to the CSR npz file of the training label matrix with shape `(N, L)`.
+* `C_path` is the path to the CSC npz file of the clustering matrix with shape `(L, K)`, where `K` is the number of leaf clusters.
+* `model_dir` is the path to the model folder where the trained model will be saved; it will be created if it does not exist.
+
+After fine-tuning, you can generate the instance embeddings via:
+```bash
+deepspeed --hostfile hostfile --module pecos.distributed.xmc.xtransformer.encode -t ${T_path} -m ${model_dir} -o ${result_dir}
+```
+where `result_dir` is the folder (under your shared network disk) to which the embeddings will be written.
+To handle large data, the embeddings are written in shards to `${result_dir}/X.emb.0.npy`, `${result_dir}/X.emb.1.npy`, ..., `${result_dir}/X.emb.[WORLD_SIZE - 1].npy`.
+
+For small data, you can also use the single-node XR-Transformer module:
+```bash
+python3 -m pecos.xmc.xtransformer.encode -t ${T_path} -m ${model_dir} -o ${result_path}
+```
+***
+
+Copyright (2021) Amazon.com, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
diff --git a/pecos/distributed/xmc/xtransformer/model.py b/pecos/distributed/xmc/xtransformer/model.py
index 79b939d..5e8d071 100644
--- a/pecos/distributed/xmc/xtransformer/model.py
+++ b/pecos/distributed/xmc/xtransformer/model.py
@@ -255,16 +255,17 @@ def train(
             parent_model.C = cur_prob.C
             parent_model.train_params = cur_train_params
             parent_model.pred_params = cur_pred_params
-            if cur_train_params.bootstrap_method == "inherit" and i > 0:
-                parent_model.text_model.inherit(prev_head, cur_prob.C, sparse=False)
-                LOGGER.info("Initialized transformer text_model from parent layer!")
-            elif cur_train_params.bootstrap_method == "no-bootstrap" or i == 0:
+
+            if cur_train_params.bootstrap_method == "no-bootstrap" or i == 0:
                 parent_model.text_model.random_init(sparse=False)
                 LOGGER.info("Randomly initialized transformer text_model!")
             else:
-                raise ValueError(
-                    f"bootstrap_method={cur_train_params.bootstrap_method} not supported in distributed training!"
-                )
+                if cur_train_params.bootstrap_method != "inherit":
+                    LOGGER.warning(
+                        f"bootstrap_method={cur_train_params.bootstrap_method} not supported in distributed training. 
Fall back to inherit" + ) + parent_model.text_model.inherit(prev_head, cur_prob.C, sparse=False) + LOGGER.info("Initialized transformer text_model from parent layer!") if cur_train_params.pre_tokenize: if not prob.is_tokenized: diff --git a/tutorials/kdd22/Session 5 XR-Transformer cookbook and Distributed PECOS.ipynb b/tutorials/kdd22/Session 5 XR-Transformer cookbook and Distributed PECOS.ipynb index 871fda8..12e1353 100644 --- a/tutorials/kdd22/Session 5 XR-Transformer cookbook and Distributed PECOS.ipynb +++ b/tutorials/kdd22/Session 5 XR-Transformer cookbook and Distributed PECOS.ipynb @@ -138,11 +138,9 @@ "text": [ "Training Parameters of Hierarchical K-means.\n", "\n", - " nr_splits (int, optional): The out-degree of each internal node of the tree. Ignored if `imbalanced_ratio != 0` because imbalanced clustering supports only 2-means. Default is `16`.\n", + " nr_splits (int, optional): The out-degree of each internal node of the tree. Default is `16`.\n", " min_codes (int): The number of direct child nodes that the top level of the hierarchy should have.\n", " max_leaf_size (int, optional): The maximum size of each leaf node of the tree. Default is `100`.\n", - " imbalanced_ratio (float, optional): Value between `0.0` and `0.5` (inclusive). Indicates how relaxed the balancedness constraint of 2-means can be. Specifically, if an iteration of 2-means is clustering `L` labels, the size of the output 2 clusters will be within approx `imbalanced_ratio * 2 * L` of each other. Default is `0.0`.\n", - " imbalanced_depth (int, optional): Maximum depth of imbalanced clustering. After depth `imbalanced_depth` is reached, balanced clustering will be used. Default is `100`.\n", " spherical (bool, optional): True will l2-normalize the centroids of k-means after each iteration. Default is `True`.\n", " seed (int, optional): Random seed. Default is `0`.\n", " kmeans_max_iter (int, optional): Maximum number of iterations for each k-means problem. 
Default is `20`.\n", @@ -160,7 +158,7 @@ "id": "18c61c61-20ad-4a4f-93a3-85a736215310", "metadata": {}, "source": [ - "Here is an example of the parameters related to label hierarchy in `eurlex-4k` model:" + "Here is an example of the parameters related to label hierarchy in `wiki10-31k` model:" ] }, { @@ -178,12 +176,10 @@ " \"class_fullname\": \"pecos.xmc.base###HierarchicalKMeans.TrainParams\"\n", " },\n", " \"nr_splits\": 16,\n", - " \"min_codes\": 16,\n", + " \"min_codes\": 128,\n", " \"max_leaf_size\": 16,\n", - " \"imbalanced_ratio\": 0.0,\n", - " \"imbalanced_depth\": 100,\n", " \"spherical\": true,\n", - " \"seed\": 0,\n", + " \"seed\": 10001,\n", " \"kmeans_max_iter\": 20,\n", " \"threads\": -1\n", "}\n" @@ -198,13 +194,13 @@ "from pecos.utils import smat_util\n", "from pecos.xmc import Indexer, LabelEmbeddingFactory\n", "\n", - "param_url = \"https://raw.githubusercontent.com/amzn/pecos/mainline/examples/xr-transformer-neurips21/params/eurlex-4k/bert/params.json\"\n", + "param_url = \"https://raw.githubusercontent.com/amzn/pecos/mainline/examples/xr-transformer-neurips21/params/wiki10-31k/bert/params.json\"\n", "params = json.loads(requests.get(param_url).text)\n", " \n", - "eurlex4k_train_params = XTransformer.TrainParams.from_dict(params[\"train_params\"])\n", - "eurlex4k_pred_params = XTransformer.PredParams.from_dict(params[\"pred_params\"])\n", + "wiki31k_train_params = XTransformer.TrainParams.from_dict(params[\"train_params\"])\n", + "wiki31k_pred_params = XTransformer.PredParams.from_dict(params[\"pred_params\"])\n", "\n", - "print(json.dumps(eurlex4k_train_params.preliminary_indexer_params.to_dict(), indent=True))" + "print(json.dumps(wiki31k_train_params.preliminary_indexer_params.to_dict(), indent=True))" ] }, { @@ -217,20 +213,20 @@ "name": "stdout", "output_type": "stream", "text": [ - "Preliminary HLT structure [16, 256, 3956]\n" + "Preliminary HLT structure [128, 2048, 30938]\n" ] } ], "source": [ - "X_feat = smat_util.load_matrix(\"work_dir/xmc-base/eurlex-4k/X.trn.npz\", dtype=np.float32)\n", - "Y = smat_util.load_matrix(\"work_dir/xmc-base/eurlex-4k/Y.trn.npz\", dtype=np.float32)\n", + "X_feat = smat_util.load_matrix(\"xmc-base/wiki10-31k/tfidf-attnxml/X.trn.npz\", dtype=np.float32)\n", + "Y = smat_util.load_matrix(\"xmc-base/wiki10-31k/Y.trn.npz\", dtype=np.float32)\n", "\n", - "with open(\"work_dir/xmc-base/eurlex-4k/X.trn.txt\", 'r') as fin:\n", + "with open(\"xmc-base/wiki10-31k/X.trn.txt\", 'r') as fin:\n", " X_txt = [xx.strip() for xx in fin.readlines()]\n", "\n", "preliminary_hlt = Indexer.gen(\n", " LabelEmbeddingFactory.create(Y, X_feat, method=\"pifa\"),\n", - " train_params=eurlex4k_train_params.preliminary_indexer_params,\n", + " train_params=wiki31k_train_params.preliminary_indexer_params,\n", ")\n", "\n", "print(f\"Preliminary HLT structure {[c.shape[0] for c in preliminary_hlt]}\")" @@ -241,7 +237,7 @@ "id": "11232266-fceb-464b-8b71-ceaf723d7b39", "metadata": {}, "source": [ - "In this case the preliminiary HLT has 3 levels (16-256-3956) and the refined HLT has 4 levels ( 4-32-256-3956).\n", + "In this case the preliminiary HLT has 3 levels (128-2048-30938).\n", "As we choose the `max_match_clusters` to be `32768`, the fine-tuning will happen on all 3 levels of preliminary HLT.\n", "\n", "The preliminary HLT is usually constructed such that:\n", @@ -335,7 +331,7 @@ "id": "38cc7480-d6a7-4c81-96cf-cb35aa823fc2", "metadata": {}, "source": [ - "For the `eurlex-4k` model, we are fine-tuning the `bert-base-uncased` pre-trained model at 3 levels of the 
preliminary HLT:" + "For the `wiki10-31k` model, we are fine-tuning the `bert-base-uncased` pre-trained model at 3 levels of the preliminary HLT:" ] }, { @@ -359,7 +355,7 @@ " \"bootstrap_method\": \"weighted-linear\",\n", " \"cache_dir\": \"\",\n", " \"checkpoint_dir\": \"\",\n", - " \"cost_sensitive_ranker\": false,\n", + " \"cost_sensitive_ranker\": true,\n", " \"eval_by_true_shorlist\": false,\n", " \"gradient_accumulation_steps\": 1,\n", " \"hidden_dropout_prob\": 0.1,\n", @@ -372,12 +368,12 @@ " \"max_grad_norm\": 1.0,\n", " \"max_no_improve_cnt\": -1,\n", " \"max_num_labels_in_gpu\": 65536,\n", - " \"max_steps\": 600,\n", + " \"max_steps\": 1000,\n", " \"model_shortcut\": \"bert-base-uncased\",\n", " \"negative_sampling\": \"tfn+man\",\n", " \"num_train_epochs\": 10,\n", - " \"pre_tensorize_labels\": false,\n", - " \"pre_tokenize\": false,\n", + " \"pre_tensorize_labels\": true,\n", + " \"pre_tokenize\": true,\n", " \"save_steps\": 200,\n", " \"threshold\": 0.001,\n", " \"use_gpu\": true,\n", @@ -388,8 +384,8 @@ } ], "source": [ - "print(\"=\"*10, f\"matcher_params_chain[0] (len={len(eurlex4k_train_params.matcher_params_chain)})\", \"=\"*10)\n", - "print(json.dumps(eurlex4k_train_params.matcher_params_chain[0].to_dict(), sort_keys=True, indent=True))" + "print(\"=\"*10, f\"matcher_params_chain[0] (len={len(wiki31k_train_params.matcher_params_chain)})\", \"=\"*10)\n", + "print(json.dumps(wiki31k_train_params.matcher_params_chain[0].to_dict(), sort_keys=True, indent=True))" ] }, { @@ -413,17 +409,40 @@ "\n", "There are two ways to provide pre-trained Transformer encoder:\n", "* **Download from huggingface repo** (https://huggingface.co/models): model name provided by `model_shortcut`. (e.x. `bert-base-uncased` or `w11wo/javanese-distilbert-small`)\n", - "* **Load from local disk**: model path provided by `init_model_dir`. Model should be loadable through `TransformerMatcher.load()`\n", + "* **Load your custom model from local disk**: model path provided by `init_model_dir`. Model should be loadable through `TransformerMatcher.load()`\n", "\n", "Note that both `model_shortcut` and `init_model_dir` will only be used in the first fine-tuning layer, as the later ones will just continue on the final state from parent encoder.\n" ] }, + { + "cell_type": "markdown", + "id": "f6e0330c-c960-4cdc-bc43-30d0bf256206", + "metadata": {}, + "source": [ + "A simple example if you want to construct your custom pre-trained model for XR-Transformer fine-tuning:" + ] + }, { "cell_type": "code", "execution_count": 7, "id": "9570261c-10b4-4c74-b018-0bd3c8b524d7", "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']\n", + "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight', 'classifier.bias']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Some weights of the model checkpoint at work_dir/my_pre_trained_model/text_encoder were not used when initializing DistilBertForXMC: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight', 'classifier.bias']\n", + "- This IS expected if you are initializing DistilBertForXMC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing DistilBertForXMC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "08/05/2022 01:43:29 - WARNING - pecos.xmc.xtransformer.matcher - XMC text_model of DistilBertForXMC not initialized from pre-trained model.\n" + ] + }, { "name": "stdout", "output_type": "stream", @@ -439,13 +458,13 @@ "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n", "\n", "init_model_dir = \"work_dir/my_pre_trained_model\"\n", - "os.makedirs(init_model_dir, exist_ok=True)\n", + "os.makedirs(\"work_dir\", exist_ok=True)\n", "\n", "# example to use your own pre-trained model, here we use huggingface model as an example\n", - "my_tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")\n", + "my_tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n", "my_encoder = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-uncased\")\n", "\n", - "# do my own pre-training/tuning/etc\n", + "# do my own modification/tuning/etc\n", "# ...\n", "\n", "# save my own model to disk\n", @@ -458,6 +477,14 @@ "print(f\"{matcher.__class__} model loaded with encoder_type={matcher.model_type} num_labels={matcher.nr_labels}\")" ] }, + { + "cell_type": "markdown", + "id": "257dcefe-9839-46bb-806a-cc83932dadfb", + "metadata": {}, + "source": [ + "Or you could download our released encoders via:" + ] + }, { "cell_type": "code", "execution_count": 8, @@ -466,10 +493,10 @@ "outputs": [], "source": [ "%%bash\n", - "DATASET=\"eurlex-4k\"\n", - "wget -q https://archive.org/download/xr-transformer-encoders/${DATASET}.tar.gz\n", + "DATASET=\"wiki10-31k\"\n", + "wget -q https://archive.org/download/xr-transformer-encoders/${DATASET}.tar.gz -O ${DATASET}_encoder.tar.gz\n", "mkdir -p ./work_dir/xr-transformer-encoder\n", - "tar -zxf ./${DATASET}.tar.gz -C ./work_dir/xr-transformer-encoder" + "tar -zxf ./${DATASET}_encoder.tar.gz -C ./work_dir/xr-transformer-encoder" ] }, { @@ -482,12 +509,12 @@ "name": "stdout", "output_type": "stream", "text": [ - " model loaded with encoder_type=bert num_labels=3956\n" + " model loaded with encoder_type=bert num_labels=30938\n" ] } ], "source": [ - "matcher = 
TransformerMatcher.load(\"./work_dir/xr-transformer-encoder/eurlex-4k/bert/text_encoder\")\n", + "matcher = TransformerMatcher.load(\"./work_dir/xr-transformer-encoder/wiki10-31k/bert/text_encoder\")\n", "print(f\"{matcher.__class__} model loaded with encoder_type={matcher.model_type} num_labels={matcher.nr_labels}\")" ] }, @@ -498,10 +525,10 @@ "source": [ "#### 2.2.2 Bootstrapping and Cost Sensitive Leanring\n", "\n", - "We provide three options to boostrap the XMC head at child level (i.e. W^(t+1)) from parent level (i.e. W^(t)):\n", - "* `bootstrap_method=None`: No bootstrap, W^(t+1) will be randomly initialized.\n", + "We provide three options to boostrap the XMC head at child level (i.e. $W^{(t+1)}$) from parent level (i.e. $W^{(t)}$):\n", + "* `bootstrap_method=None`: No bootstrap, $W^{(t+1)}$ will be randomly initialized.\n", "* `bootstrap_method='inherit'`: Bootstrap by inherit the weight vector from parent node. \n", - "* `bootstrap_method='linear'`(default): linear model will be trained on final embeddings of parent layer and be used as initial point for W^(t+1).\n", + "* `bootstrap_method='linear'`(default): linear model will be trained on final embeddings of parent layer and be used as initial point for $W^{(t+1)}$.\n", "\n", "In most cases the default linear bootstrapper would give good enough initial point the XMC heads.\n", "Compared with linear bootstrapper, the inherit bootstrapper has less memory/time overhead. " @@ -556,12 +583,21 @@ "id": "f8531b6d-f96f-476e-836e-01a2f6daa35f", "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForXMC: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n", + "- This IS expected if you are initializing BertForXMC from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing BertForXMC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "prec = 84.97 78.05 71.25 64.93 58.97 53.42 48.24 43.70 39.92 36.81\n", - "recall = 17.26 31.35 42.42 51.00 57.39 62.01 65.08 67.23 68.95 70.56\n" + "prec = 85.22 82.55 77.26 72.15 67.42 63.13 59.33 56.08 53.02 50.24\n", + "recall = 5.05 9.76 13.58 16.74 19.41 21.68 23.64 25.41 26.92 28.22\n" ] } ], @@ -570,19 +606,20 @@ "prob = MLProblemWithText(X_txt, Y, X_feat=X_feat)\n", "\n", "# disable fine-tuning, use pre-trained bert model from huggingface\n", - "eurlex4k_train_params.do_fine_tune = False\n", + "wiki31k_train_params.do_fine_tune = False\n", "\n", + "# this will be slow on CPU only machine\n", "xtf_pretrained = XTransformer.train(\n", " prob,\n", " clustering=preliminary_hlt,\n", - " train_params=eurlex4k_train_params,\n", - " pred_params=eurlex4k_pred_params,\n", + " train_params=wiki31k_train_params,\n", + " pred_params=wiki31k_pred_params,\n", ")\n", "\n", - "X_feat_tst = smat_util.load_matrix(\"work_dir/xmc-base/eurlex-4k/X.tst.npz\", dtype=np.float32)\n", - "Y_tst = smat_util.load_matrix(\"work_dir/xmc-base/eurlex-4k/Y.tst.npz\", dtype=np.float32)\n", + "X_feat_tst = smat_util.load_matrix(\"xmc-base/wiki10-31k/tfidf-attnxml/X.tst.npz\", dtype=np.float32)\n", + "Y_tst = smat_util.load_matrix(\"xmc-base/wiki10-31k/Y.tst.npz\", dtype=np.float32)\n", "\n", - "with open(\"work_dir/xmc-base/eurlex-4k/X.tst.txt\", 'r') as fin:\n", + "with open(\"xmc-base/wiki10-31k/X.tst.txt\", 'r') as fin:\n", " X_txt_tst = [xx.strip() for xx in fin.readlines()]\n", "\n", "P_pretrained = xtf_pretrained.predict(X_txt_tst, X_feat=X_feat_tst)\n", @@ -600,20 +637,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "prec = 87.17 80.98 74.50 68.13 61.61 55.58 50.18 45.46 41.51 38.12\n", - "recall = 17.72 32.60 44.41 53.52 59.92 64.44 67.58 69.79 71.55 72.89\n" + "prec = 87.95 83.54 78.79 73.95 69.43 65.14 61.08 57.70 54.63 51.97\n", + "recall = 5.25 9.89 13.84 17.14 19.99 22.36 24.35 26.16 27.73 29.21\n" ] } ], "source": [ "# use fine-tuned bert model\n", - "eurlex4k_train_params.matcher_params_chain[0].init_model_dir = \"./work_dir/xr-transformer-encoder/eurlex-4k/bert/text_encoder\"\n", + "wiki31k_train_params.matcher_params_chain[0].init_model_dir = \"./work_dir/xr-transformer-encoder/wiki10-31k/bert/text_encoder\"\n", "\n", + "# this will be slow on CPU only machine\n", "xtf_fine_tuned = XTransformer.train(\n", " prob,\n", " clustering=preliminary_hlt,\n", - " train_params=eurlex4k_train_params,\n", - " pred_params=eurlex4k_pred_params,\n", + " train_params=wiki31k_train_params,\n", + " pred_params=wiki31k_pred_params,\n", ")\n", "\n", "P_fine_tuned = xtf_fine_tuned.predict(X_txt_tst, X_feat=X_feat_tst)\n", @@ -662,7 +700,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 13, "id": "2834997c-4bd9-4732-acb6-f7e4e7f6a090", "metadata": {}, "outputs": [ @@ -684,14 +722,30 @@ "output_type": "stream", "text": [ "/opt/amazon/openmpi/bin/mpiexec\n", - "Hello, World! I am process 3 of 8 on ip-172-31-8-94.ec2.internal.\n", - "Hello, World! I am process 0 of 8 on ip-172-31-8-94.ec2.internal.\n", - "Hello, World! I am process 1 of 8 on ip-172-31-8-94.ec2.internal.\n", - "Hello, World! 
I am process 2 of 8 on ip-172-31-8-94.ec2.internal.\n", - "Hello, World! I am process 4 of 8 on ip-172-31-8-94.ec2.internal.\n", - "Hello, World! I am process 5 of 8 on ip-172-31-8-94.ec2.internal.\n", - "Hello, World! I am process 6 of 8 on ip-172-31-8-94.ec2.internal.\n", - "Hello, World! I am process 7 of 8 on ip-172-31-8-94.ec2.internal.\n" + "Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n", + "Requirement already satisfied: mpi4py in /home/ec2-user/repo/tutorial-env/lib/python3.9/site-packages (3.1.3)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: You are using pip version 20.2.3; however, version 22.2.2 is available.\n", + "You should consider upgrading via the '/home/ec2-user/repo/tutorial-env/bin/python3 -m pip install --upgrade pip' command.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Hello, World! I am process 0 of 8 on ip-[MASKED].ec2.internal.\n", + "Hello, World! I am process 1 of 8 on ip-[MASKED].ec2.internal.\n", + "Hello, World! I am process 2 of 8 on ip-[MASKED].ec2.internal.\n", + "Hello, World! I am process 3 of 8 on ip-[MASKED].ec2.internal.\n", + "Hello, World! I am process 4 of 8 on ip-[MASKED].ec2.internal.\n", + "Hello, World! I am process 5 of 8 on ip-[MASKED].ec2.internal.\n", + "Hello, World! I am process 6 of 8 on ip-[MASKED].ec2.internal.\n", + "Hello, World! I am process 7 of 8 on ip-[MASKED].ec2.internal.\n" ] } ], @@ -716,13 +770,9 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 14, "id": "f3b7e689-480d-4c81-a408-d4b840695880", "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - }, "tags": [] }, "outputs": [ @@ -730,69 +780,59 @@ "name": "stderr", "output_type": "stream", "text": [ - "07/07/2022 22:33:34 - INFO - pecos.utils.profile_util - psutil module installed, will print memory info.\n", - "07/07/2022 22:33:34 - INFO - __main__ - Started loading data on Rank 1 ... RSS 91.0 MB. Full mem info: pmem(rss=95412224, vms=35591151616, shared=47382528, text=2732032, lib=0, data=166395904, dirty=0)\n", - "07/07/2022 22:33:34 - INFO - pecos.utils.profile_util - psutil module installed, will print memory info.\n", - "07/07/2022 22:33:34 - INFO - __main__ - Started loading data on Rank 0 ... RSS 91.2 MB. Full mem info: pmem(rss=95682560, vms=35591151616, shared=47669248, text=2732032, lib=0, data=166395904, dirty=0)\n", - "07/07/2022 22:33:34 - INFO - __main__ - Done loading data on Rank 0. RSS 126.0 MB. Full mem info: pmem(rss=132136960, vms=35626360832, shared=47800320, text=2732032, lib=0, data=201605120, dirty=0)\n", - "07/07/2022 22:33:34 - INFO - pecos.distributed.xmc.base - Starts creating label embedding PIFA for meta tree on Rank 0 node... RSS 126.0 MB. Full mem info: pmem(rss=132136960, vms=35626360832, shared=47800320, text=2732032, lib=0, data=201605120, dirty=0)\n", - "07/07/2022 22:33:34 - INFO - __main__ - Done loading data on Rank 1. RSS 125.8 MB. Full mem info: pmem(rss=131928064, vms=35626360832, shared=47579136, text=2732032, lib=0, data=201605120, dirty=0)\n", - "07/07/2022 22:33:34 - INFO - pecos.distributed.xmc.base - Done creating label embedding PIFA for meta tree on Rank 0 node. RSS 198.7 MB. 
Full mem info: pmem(rss=208375808, vms=35776598016, shared=48451584, text=2732032, lib=0, data=285478912, dirty=0)\n", - "07/07/2022 22:33:34 - INFO - pecos.distributed.xmc.base - Starts generating meta tree cluster on main node...\n", - "07/07/2022 22:33:34 - INFO - pecos.distributed.xmc.base - Determined meta-tree leaf clusters number: 4. 2 nodes will train 4 sub-trees. Number of data labels: 3956, nr_splits: 16\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Done generating meta tree cluster. RSS 225.4 MB. Full mem info: pmem(rss=236306432, vms=35804254208, shared=48713728, text=2732032, lib=0, data=313135104, dirty=0)\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Rank 0 get 2 sub-tree assignments.\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Rank 1 get 2 sub-tree assignments.\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - On rank 0, 0th sub-tree assignment has 989 labels: [0, 1, 2, 4, 5, 6, 7, 8, 9, 11]...\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - On rank 1, 0th sub-tree assignment has 989 labels: [18, 22, 31, 32, 35, 37, 38, 39, 57, 62]...\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Starts creating label embedding PIFA for 0th sub-tree on rank 0... RSS 155.5 MB. Full mem info: pmem(rss=163016704, vms=35730997248, shared=48934912, text=2732032, lib=0, data=239878144, dirty=0)\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Starts creating label embedding PIFA for 0th sub-tree on rank 1... RSS 125.8 MB. Full mem info: pmem(rss=131928064, vms=35626622976, shared=47579136, text=2732032, lib=0, data=201867264, dirty=0)\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Done creating label embedding PIFA for 0th sub-tree on rank 0. RSS 160.6 MB. Full mem info: pmem(rss=168361984, vms=35734740992, shared=49242112, text=2732032, lib=0, data=245108736, dirty=0)\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Starts generating 0th sub-tree cluster on rank 0...\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Done creating label embedding PIFA for 0th sub-tree on rank 1. RSS 148.8 MB. Full mem info: pmem(rss=156049408, vms=35724161024, shared=48783360, text=2732032, lib=0, data=233041920, dirty=0)\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Starts generating 0th sub-tree cluster on rank 1...\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Done generating 0th sub-tree cluster on rank 0. RSS 160.8 MB. Full mem info: pmem(rss=168628224, vms=35734740992, shared=49242112, text=2732032, lib=0, data=245108736, dirty=0)\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - On rank 0, 1th sub-tree assignment has 989 labels: [3, 10, 12, 13, 17, 21, 26, 33, 34, 36]...\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Starts creating label embedding PIFA for 1th sub-tree on rank 0... RSS 160.8 MB. Full mem info: pmem(rss=168628224, vms=35734740992, shared=49242112, text=2732032, lib=0, data=245108736, dirty=0)\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Done creating label embedding PIFA for 1th sub-tree on rank 0. RSS 179.4 MB. 
Full mem info: pmem(rss=188153856, vms=35754147840, shared=49242112, text=2732032, lib=0, data=264515584, dirty=0)\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Starts generating 1th sub-tree cluster on rank 0...\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Done generating 0th sub-tree cluster on rank 1. RSS 159.7 MB. Full mem info: pmem(rss=167444480, vms=35735547904, shared=48979968, text=2732032, lib=0, data=244428800, dirty=0)\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - On rank 1, 1th sub-tree assignment has 989 labels: [14, 20, 30, 46, 50, 54, 60, 78, 85, 100]...\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Starts creating label embedding PIFA for 1th sub-tree on rank 1... RSS 159.7 MB. Full mem info: pmem(rss=167444480, vms=35735547904, shared=48979968, text=2732032, lib=0, data=244428800, dirty=0)\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Done generating 1th sub-tree cluster on rank 0. RSS 179.4 MB. Full mem info: pmem(rss=188153856, vms=35754147840, shared=49242112, text=2732032, lib=0, data=264515584, dirty=0)\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Done creating label embedding PIFA for 1th sub-tree on rank 1. RSS 176.8 MB. Full mem info: pmem(rss=185393152, vms=35751538688, shared=49176576, text=2732032, lib=0, data=261906432, dirty=0)\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Starts generating 1th sub-tree cluster on rank 1...\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Done generating 1th sub-tree cluster on rank 1. RSS 176.8 MB. Full mem info: pmem(rss=185393152, vms=35751538688, shared=49176576, text=2732032, lib=0, data=261906432, dirty=0)\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Starts assmebling cluster chain... RSS 129.2 MB. Full mem info: pmem(rss=135426048, vms=35701280768, shared=49242112, text=2732032, lib=0, data=211648512, dirty=0)\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Done assmebling cluster chain. Split depth: 1. Chain length: 3 RSS 129.2 MB. Full mem info: pmem(rss=135426048, vms=35701280768, shared=49242112, text=2732032, lib=0, data=211648512, dirty=0)\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Broadcasting distributed cluster chain from Node 0...\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - Done broadcast distributed cluster chain from Node 0.\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.xlinear.model - meta, sub negative samples: 32 61\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.xlinear.model - Starts receiving sub-training jobs from source 0 for rank 1...\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - meta_tree_leaf_cluster: (3956, 64)\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.xlinear.model - Main node workload: 69941.31147540984\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.xlinear.model - Min worker node workload, machine rank: (69387, 0). 
Max worker node workload, machine rank: (69387, 0)\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.xlinear.model - Training jobs for all Sub-trees divided onto 2 machines: Main node will train for 13 sub-trees, Worker nodes will train for [51] sub-trees, worker receive order: [1].\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.xlinear.model - Starts sending sub-training jobs from node 0 to 1...\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.xlinear.model - Done sending sub-training jobs from node 0 to 1.\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.xlinear.model - Rank 0 starts meta-tree training... RSS 130.0 MB. Full mem info: pmem(rss=136282112, vms=35702648832, shared=49307648, text=2732032, lib=0, data=213016576, dirty=0)\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.xlinear.model - Done receiving sub-training jobs from source 0 for rank 1.\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.xlinear.model - Rank 1 get 51 sub-trees to train\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.xlinear.model - Rank 1 starts sub-tree training... RSS 129.0 MB. Full mem info: pmem(rss=135274496, vms=35701280768, shared=49176576, text=2732032, lib=0, data=211648512, dirty=0)\n", - "07/07/2022 22:33:35 - INFO - pecos.distributed.xmc.base - meta_tree_leaf_cluster: (3956, 64)\n", - "07/07/2022 22:33:39 - INFO - pecos.distributed.xmc.xlinear.model - Rank 0 done meta-tree training. RSS 163.5 MB. Full mem info: pmem(rss=171479040, vms=35735203840, shared=49307648, text=2732032, lib=0, data=250142720, dirty=0)\n", - "07/07/2022 22:33:39 - INFO - pecos.distributed.xmc.xlinear.model - Rank 0 get 13 sub-trees to train\n", - "07/07/2022 22:33:39 - INFO - pecos.distributed.xmc.xlinear.model - Rank 0 starts sub-tree training... RSS 163.5 MB. Full mem info: pmem(rss=171479040, vms=35735203840, shared=49307648, text=2732032, lib=0, data=250142720, dirty=0)\n", - "07/07/2022 22:33:42 - INFO - pecos.distributed.xmc.xlinear.model - Rank 0 total 13 sub-tree training finished. RSS 163.5 MB. Full mem info: pmem(rss=171479040, vms=35735203840, shared=49307648, text=2732032, lib=0, data=250142720, dirty=0)\n", - "07/07/2022 22:33:42 - INFO - pecos.distributed.xmc.xlinear.model - Main node start recv 51 sub-tree models from rank 1\n", - "07/07/2022 22:33:48 - INFO - pecos.distributed.xmc.xlinear.model - Rank 1 total 51 sub-tree training finished. RSS 148.8 MB. Full mem info: pmem(rss=155975680, vms=35721224192, shared=49369088, text=2732032, lib=0, data=232181760, dirty=0)\n", - "07/07/2022 22:33:48 - INFO - pecos.distributed.xmc.xlinear.model - Rank 1 node starts sending 51 sub-tree models.\n", - "07/07/2022 22:33:48 - INFO - pecos.distributed.xmc.xlinear.model - Main node done receive 51 sub-tree models from rank 1\n", - "07/07/2022 22:33:48 - INFO - pecos.distributed.xmc.xlinear.model - Rank 1 node done sending 51 sub-tree models.\n", - "07/07/2022 22:33:48 - INFO - pecos.distributed.xmc.xlinear.model - Reconstruct full model on Rank 0 node... RSS 163.9 MB. Full mem info: pmem(rss=171864064, vms=35735465984, shared=49446912, text=2732032, lib=0, data=250404864, dirty=0)\n", - "07/07/2022 22:33:48 - INFO - pecos.distributed.xmc.xlinear.model - Done reconstruct full model on Rank 0 node. RSS 164.7 MB. 
Full mem info: pmem(rss=172675072, vms=35735990272, shared=49446912, text=2732032, lib=0, data=250929152, dirty=0)\n", - "07/07/2022 22:33:48 - INFO - __main__ - Saving model to work_dir/dist_xlinear_model...\n", - "07/07/2022 22:33:49 - INFO - __main__ - Done saving model.\n" + "08/05/2022 01:53:16 - INFO - pecos.utils.profile_util - psutil module installed, will print memory info.\n", + "08/05/2022 01:53:16 - INFO - pecos.utils.profile_util - psutil module installed, will print memory info.\n", + "08/05/2022 01:53:16 - INFO - __main__ - Started loading data on Rank 0 ... RSS 89.9 MB. Full mem info: pmem(rss=94277632, vms=805863424, shared=45461504, text=2732032, lib=0, data=156913664, dirty=0)\n", + "08/05/2022 01:53:16 - INFO - __main__ - Started loading data on Rank 1 ... RSS 89.7 MB. Full mem info: pmem(rss=94011392, vms=805867520, shared=45207552, text=2732032, lib=0, data=156917760, dirty=0)\n", + "08/05/2022 01:53:16 - INFO - __main__ - Done loading data on Rank 1. RSS 166.4 MB. Full mem info: pmem(rss=174432256, vms=884928512, shared=45535232, text=2732032, lib=0, data=235978752, dirty=0)\n", + "08/05/2022 01:53:16 - INFO - __main__ - Done loading data on Rank 0. RSS 166.5 MB. Full mem info: pmem(rss=174637056, vms=884924416, shared=45735936, text=2732032, lib=0, data=235974656, dirty=0)\n", + "08/05/2022 01:53:16 - INFO - pecos.distributed.xmc.base - Starts creating label embedding PIFA for meta tree on Rank 0 node... RSS 166.5 MB. Full mem info: pmem(rss=174637056, vms=884924416, shared=45735936, text=2732032, lib=0, data=235974656, dirty=0)\n", + "08/05/2022 01:53:19 - INFO - pecos.distributed.xmc.base - Done creating label embedding PIFA for meta tree on Rank 0 node. RSS 1005.2 MB. Full mem info: pmem(rss=1054044160, vms=1838071808, shared=46354432, text=2732032, lib=0, data=1123237888, dirty=0)\n", + "08/05/2022 01:53:19 - INFO - pecos.distributed.xmc.base - Starts generating meta tree cluster on main node...\n", + "08/05/2022 01:53:19 - INFO - pecos.distributed.xmc.base - Determined meta-tree leaf clusters number: 2. 2 nodes will train 2 sub-trees. Number of data labels: 30938, nr_splits: 16\n", + "08/05/2022 01:53:23 - INFO - pecos.distributed.xmc.base - Done generating meta tree cluster. RSS 1005.7 MB. Full mem info: pmem(rss=1054523392, vms=1838071808, shared=46813184, text=2732032, lib=0, data=1123237888, dirty=0)\n", + "08/05/2022 01:53:23 - INFO - pecos.distributed.xmc.base - Rank 0 get 1 sub-tree assignments.\n", + "08/05/2022 01:53:23 - INFO - pecos.distributed.xmc.base - On rank 0, 0th sub-tree assignment has 15469 labels: [2, 4, 6, 8, 9, 10, 12, 14, 15, 18]...\n", + "08/05/2022 01:53:23 - INFO - pecos.distributed.xmc.base - Rank 1 get 1 sub-tree assignments.\n", + "08/05/2022 01:53:23 - INFO - pecos.distributed.xmc.base - On rank 1, 0th sub-tree assignment has 15469 labels: [0, 1, 3, 5, 7, 11, 13, 16, 17, 19]...\n", + "08/05/2022 01:53:23 - INFO - pecos.distributed.xmc.base - Starts creating label embedding PIFA for 0th sub-tree on rank 0... RSS 170.4 MB. Full mem info: pmem(rss=178642944, vms=962002944, shared=47120384, text=2732032, lib=0, data=247169024, dirty=0)\n", + "08/05/2022 01:53:23 - INFO - pecos.distributed.xmc.base - Starts creating label embedding PIFA for 0th sub-tree on rank 1... RSS 166.8 MB. Full mem info: pmem(rss=174923776, vms=885452800, shared=45776896, text=2732032, lib=0, data=236503040, dirty=0)\n", + "08/05/2022 01:53:25 - INFO - pecos.distributed.xmc.base - Done creating label embedding PIFA for 0th sub-tree on rank 1. RSS 473.8 MB. 
Full mem info: pmem(rss=496820224, vms=1280671744, shared=46534656, text=2732032, lib=0, data=565837824, dirty=0)\n", + "08/05/2022 01:53:25 - INFO - pecos.distributed.xmc.base - Starts generating 0th sub-tree cluster on rank 1...\n", + "08/05/2022 01:53:26 - INFO - pecos.distributed.xmc.base - Done creating label embedding PIFA for 0th sub-tree on rank 0. RSS 706.0 MB. Full mem info: pmem(rss=740335616, vms=1523580928, shared=47185920, text=2732032, lib=0, data=808747008, dirty=0)\n", + "08/05/2022 01:53:26 - INFO - pecos.distributed.xmc.base - Starts generating 0th sub-tree cluster on rank 0...\n", + "08/05/2022 01:53:31 - INFO - pecos.distributed.xmc.base - Done generating 0th sub-tree cluster on rank 1. RSS 474.4 MB. Full mem info: pmem(rss=497467392, vms=1280671744, shared=47153152, text=2732032, lib=0, data=565837824, dirty=0)\n", + "08/05/2022 01:53:37 - INFO - pecos.distributed.xmc.base - Done generating 0th sub-tree cluster on rank 0. RSS 706.1 MB. Full mem info: pmem(rss=740356096, vms=1523580928, shared=47194112, text=2732032, lib=0, data=808747008, dirty=0)\n", + "08/05/2022 01:53:37 - INFO - pecos.distributed.xmc.base - Starts assmebling cluster chain... RSS 172.2 MB. Full mem info: pmem(rss=180523008, vms=963616768, shared=47194112, text=2732032, lib=0, data=248782848, dirty=0)\n", + "08/05/2022 01:53:37 - INFO - pecos.distributed.xmc.base - Done assmebling cluster chain. Split depth: 1. Chain length: 4 RSS 172.2 MB. Full mem info: pmem(rss=180523008, vms=963616768, shared=47194112, text=2732032, lib=0, data=248782848, dirty=0)\n", + "08/05/2022 01:53:37 - INFO - pecos.distributed.xmc.base - Broadcasting distributed cluster chain from Node 0...\n", + "08/05/2022 01:53:37 - INFO - pecos.distributed.xmc.base - Done broadcast distributed cluster chain from Node 0.\n", + "08/05/2022 01:53:37 - INFO - pecos.distributed.xmc.xlinear.model - Starts receiving sub-training jobs from source 0 for rank 1...\n", + "08/05/2022 01:53:37 - INFO - pecos.distributed.xmc.xlinear.model - meta, sub negative samples: 32 76\n", + "08/05/2022 01:53:37 - INFO - pecos.distributed.xmc.base - meta_tree_leaf_cluster: (30938, 32)\n", + "08/05/2022 01:53:37 - INFO - pecos.distributed.xmc.xlinear.model - Main node workload: 216470.70175438595\n", + "08/05/2022 01:53:37 - INFO - pecos.distributed.xmc.xlinear.model - Min worker node workload, machine rank: (204720, 0). Max worker node workload, machine rank: (204720, 0)\n", + "08/05/2022 01:53:37 - INFO - pecos.distributed.xmc.xlinear.model - Training jobs for all Sub-trees divided onto 2 machines: Main node will train for 7 sub-trees, Worker nodes will train for [25] sub-trees, worker receive order: [1].\n", + "08/05/2022 01:53:37 - INFO - pecos.distributed.xmc.xlinear.model - Starts sending sub-training jobs from node 0 to 1...\n", + "08/05/2022 01:53:37 - INFO - pecos.distributed.xmc.xlinear.model - Done sending sub-training jobs from node 0 to 1.\n", + "08/05/2022 01:53:37 - INFO - pecos.distributed.xmc.xlinear.model - Rank 0 starts meta-tree training... RSS 173.0 MB. Full mem info: pmem(rss=181362688, vms=964771840, shared=47194112, text=2732032, lib=0, data=249937920, dirty=0)\n", + "08/05/2022 01:53:37 - INFO - pecos.distributed.xmc.xlinear.model - Done receiving sub-training jobs from source 0 for rank 1.\n", + "08/05/2022 01:53:37 - INFO - pecos.distributed.xmc.xlinear.model - Rank 1 get 25 sub-trees to train\n", + "08/05/2022 01:53:37 - INFO - pecos.distributed.xmc.xlinear.model - Rank 1 starts sub-tree training... RSS 170.2 MB. 
Full mem info: pmem(rss=178483200, vms=961769472, shared=47218688, text=2732032, lib=0, data=246935552, dirty=0)\n", + "08/05/2022 01:53:37 - INFO - pecos.distributed.xmc.base - meta_tree_leaf_cluster: (30938, 32)\n", + "08/05/2022 01:53:52 - INFO - pecos.distributed.xmc.xlinear.model - Rank 0 done meta-tree training. RSS 195.2 MB. Full mem info: pmem(rss=204632064, vms=983449600, shared=47259648, text=2732032, lib=0, data=273154048, dirty=0)\n", + "08/05/2022 01:53:52 - INFO - pecos.distributed.xmc.xlinear.model - Rank 0 get 7 sub-trees to train\n", + "08/05/2022 01:53:52 - INFO - pecos.distributed.xmc.xlinear.model - Rank 0 starts sub-tree training... RSS 195.2 MB. Full mem info: pmem(rss=204632064, vms=983449600, shared=47259648, text=2732032, lib=0, data=273154048, dirty=0)\n", + "08/05/2022 01:54:37 - INFO - pecos.distributed.xmc.xlinear.model - Rank 0 total 7 sub-tree training finished. RSS 239.9 MB. Full mem info: pmem(rss=251518976, vms=1030242304, shared=47325184, text=2732032, lib=0, data=319946752, dirty=0)\n", + "08/05/2022 01:54:37 - INFO - pecos.distributed.xmc.xlinear.model - Main node start recv 25 sub-tree models from rank 1\n", + "08/05/2022 01:55:42 - INFO - pecos.distributed.xmc.xlinear.model - Rank 1 total 25 sub-tree training finished. RSS 251.5 MB. Full mem info: pmem(rss=263733248, vms=1044664320, shared=47349760, text=2732032, lib=0, data=334344192, dirty=0)\n", + "08/05/2022 01:55:42 - INFO - pecos.distributed.xmc.xlinear.model - Rank 1 node starts sending 25 sub-tree models.\n", + "08/05/2022 01:55:42 - INFO - pecos.distributed.xmc.xlinear.model - Main node done receive 25 sub-tree models from rank 1\n", + "08/05/2022 01:55:42 - INFO - pecos.distributed.xmc.xlinear.model - Rank 1 node done sending 25 sub-tree models.\n", + "08/05/2022 01:55:42 - INFO - pecos.distributed.xmc.xlinear.model - Reconstruct full model on Rank 0 node... RSS 240.2 MB. Full mem info: pmem(rss=251850752, vms=1030242304, shared=47398912, text=2732032, lib=0, data=319946752, dirty=0)\n", + "08/05/2022 01:55:42 - INFO - pecos.distributed.xmc.xlinear.model - Done reconstruct full model on Rank 0 node. RSS 315.7 MB. 
Full mem info: pmem(rss=331001856, vms=1109413888, shared=47398912, text=2732032, lib=0, data=399118336, dirty=0)\n", + "08/05/2022 01:55:42 - INFO - __main__ - Saving model to work_dir/dist_xlinear_model...\n", + "08/05/2022 01:55:42 - INFO - __main__ - Done saving model.\n" ] } ], @@ -800,8 +840,8 @@ "%%bash\n", "mpiexec -n 2 \\\n", "python3 -m pecos.distributed.xmc.xlinear.train \\\n", - "-x work_dir/xmc-base/eurlex-4k/X.trn.npz \\\n", - "-y work_dir/xmc-base/eurlex-4k/Y.trn.npz \\\n", + "-x xmc-base/wiki10-31k/tfidf-attnxml/X.trn.npz \\\n", + "-y xmc-base/wiki10-31k/Y.trn.npz \\\n", "-m work_dir/dist_xlinear_model" ] }, @@ -820,7 +860,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 15, "id": "b39cbf96-6cf2-4721-86d7-90cc27bdbe1f", "metadata": {}, "outputs": [ @@ -829,16 +869,16 @@ "output_type": "stream", "text": [ "==== evaluation results ====\n", - "prec = 82.25 74.92 68.58 62.92 57.58 52.55 47.68 43.56 40.08 37.08\n", - "recall = 16.62 29.97 40.75 49.33 55.96 60.91 64.26 66.97 69.14 70.97\n" + "prec = 84.05 78.20 72.57 67.90 63.90 60.17 56.86 53.87 51.31 48.82\n", + "recall = 4.97 9.17 12.63 15.62 18.26 20.52 22.48 24.25 25.88 27.25\n" ] } ], "source": [ "%%bash\n", "python3 -m pecos.xmc.xlinear.predict \\\n", - "-x work_dir/xmc-base/eurlex-4k/X.tst.npz \\\n", - "-y work_dir/xmc-base/eurlex-4k/Y.tst.npz \\\n", + "-x xmc-base/wiki10-31k/tfidf-attnxml/X.tst.npz \\\n", + "-y xmc-base/wiki10-31k/Y.tst.npz \\\n", "-m work_dir/dist_xlinear_model" ] }, @@ -886,7 +926,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.10" + "version": "3.9.4" } }, "nbformat": 4,