From 6b8cf35f36fbcd73ba4fa99f37e61547259a7960 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Fri, 24 Jul 2020 14:03:21 +0800 Subject: [PATCH 1/2] Add API to query intermediate results --- .../nni/nas/benchmarks/nasbench101/model.py | 3 +- .../nni/nas/benchmarks/nasbench101/query.py | 16 ++++++++-- .../nni/nas/benchmarks/nasbench201/model.py | 3 +- .../nni/nas/benchmarks/nasbench201/query.py | 31 +++++++++++++++++-- src/sdk/pynni/nni/nas/benchmarks/nds/model.py | 5 +-- src/sdk/pynni/nni/nas/benchmarks/nds/query.py | 16 ++++++++-- src/sdk/pynni/nni/nas/benchmarks/utils.py | 5 +++ 7 files changed, 66 insertions(+), 13 deletions(-) create mode 100644 src/sdk/pynni/nni/nas/benchmarks/utils.py diff --git a/src/sdk/pynni/nni/nas/benchmarks/nasbench101/model.py b/src/sdk/pynni/nni/nas/benchmarks/nasbench101/model.py index 00a2596a5f..44ec3f874f 100644 --- a/src/sdk/pynni/nni/nas/benchmarks/nasbench101/model.py +++ b/src/sdk/pynni/nni/nas/benchmarks/nasbench101/model.py @@ -4,6 +4,7 @@ from playhouse.sqlite_ext import JSONField, SqliteExtDatabase from nni.nas.benchmarks.constants import DATABASE_DIR +from nni.nas.benchmarks.utils import json_dumps db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nasbench101.db'), autoconnect=True) @@ -28,7 +29,7 @@ class Nb101TrialConfig(Model): Number of epochs planned for this trial. Should be one of 4, 12, 36, 108 in default setup. """ - arch = JSONField(index=True) + arch = JSONField(json_dumps=json_dumps, index=True) num_vertices = IntegerField(index=True) hash = CharField(max_length=64, index=True) num_epochs = IntegerField(index=True) diff --git a/src/sdk/pynni/nni/nas/benchmarks/nasbench101/query.py b/src/sdk/pynni/nni/nas/benchmarks/nasbench101/query.py index 28c0dc03be..1c54c24f46 100644 --- a/src/sdk/pynni/nni/nas/benchmarks/nasbench101/query.py +++ b/src/sdk/pynni/nni/nas/benchmarks/nasbench101/query.py @@ -6,7 +6,7 @@ from .graph_util import hash_module, infer_num_vertices -def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None): +def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None, include_intermediates=False): """ Query trial stats of NAS-Bench-101 given conditions. @@ -24,6 +24,8 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None): reduction : str or None If 'none' or None, all trial stats will be returned directly. If 'mean', fields in trial stats will be averaged given the same trial config. + include_intermediates : boolean + If true, intermediate results will be returned. Returns ------- @@ -56,5 +58,13 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None): query = query.where(functools.reduce(lambda a, b: a & b, conditions)) if reduction is not None: query = query.group_by(Nb101TrialStats.config) - for k in query: - yield model_to_dict(k) + for trial in query: + if include_intermediates: + data = model_to_dict(trial) + # exclude 'trial' from intermediates as it is already available in data + data['intermediates'] = [ + {k: v for k, v in model_to_dict(t).items() if k != 'trial'} for t in trial.intermediates + ] + yield data + else: + yield model_to_dict(trial) diff --git a/src/sdk/pynni/nni/nas/benchmarks/nasbench201/model.py b/src/sdk/pynni/nni/nas/benchmarks/nasbench201/model.py index 3e3322f921..3b898de7c8 100644 --- a/src/sdk/pynni/nni/nas/benchmarks/nasbench201/model.py +++ b/src/sdk/pynni/nni/nas/benchmarks/nasbench201/model.py @@ -4,6 +4,7 @@ from playhouse.sqlite_ext import JSONField, SqliteExtDatabase from nni.nas.benchmarks.constants import DATABASE_DIR +from nni.nas.benchmarks.utils import json_dumps db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nasbench201.db'), autoconnect=True) @@ -35,7 +36,7 @@ class Nb201TrialConfig(Model): for training, 6k images from validation set for validation and the other 6k for testing). """ - arch = JSONField(index=True) + arch = JSONField(json_dumps=json_dumps, index=True) num_epochs = IntegerField(index=True) num_channels = IntegerField() num_cells = IntegerField() diff --git a/src/sdk/pynni/nni/nas/benchmarks/nasbench201/query.py b/src/sdk/pynni/nni/nas/benchmarks/nasbench201/query.py index 3e5a29dea0..6b51725d6d 100644 --- a/src/sdk/pynni/nni/nas/benchmarks/nasbench201/query.py +++ b/src/sdk/pynni/nni/nas/benchmarks/nasbench201/query.py @@ -5,7 +5,7 @@ from .model import Nb201TrialStats, Nb201TrialConfig -def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None): +def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_intermediates=False): """ Query trial stats of NAS-Bench-201 given conditions. @@ -23,6 +23,8 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None): reduction : str or None If 'none' or None, all trial stats will be returned directly. If 'mean', fields in trial stats will be averaged given the same trial config. + include_intermediates : boolean + If true, intermediate results will be returned. Returns ------- @@ -53,5 +55,28 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None): query = query.where(functools.reduce(lambda a, b: a & b, conditions)) if reduction is not None: query = query.group_by(Nb201TrialStats.config) - for k in query: - yield model_to_dict(k) + for trial in query: + if include_intermediates: + data = model_to_dict(trial) + # exclude 'trial' from intermediates as it is already available in data + data['intermediates'] = [ + {k: v for k, v in model_to_dict(t).items() if k != 'trial'} for t in trial.intermediates + ] + yield data + else: + yield model_to_dict(trial) + + +if __name__ == "__main__": + import pprint + arch = { + '0_1': 'avg_pool_3x3', + '0_2': 'conv_1x1', + '0_3': 'conv_1x1', + '1_2': 'skip_connect', + '1_3': 'skip_connect', + '2_3': 'skip_connect' + } + for t in query_nb201_trial_stats(arch, 200, 'cifar100', include_intermediates=False): + pprint.pprint(t) + break diff --git a/src/sdk/pynni/nni/nas/benchmarks/nds/model.py b/src/sdk/pynni/nni/nas/benchmarks/nds/model.py index d7f6894da1..a6ace351d7 100644 --- a/src/sdk/pynni/nni/nas/benchmarks/nds/model.py +++ b/src/sdk/pynni/nni/nas/benchmarks/nds/model.py @@ -4,6 +4,7 @@ from playhouse.sqlite_ext import JSONField, SqliteExtDatabase from nni.nas.benchmarks.constants import DATABASE_DIR +from nni.nas.benchmarks.utils import json_dumps db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nds.db'), autoconnect=True) @@ -52,8 +53,8 @@ class NdsTrialConfig(Model): 'residual_basic', 'vanilla', ]) - model_spec = JSONField(index=True) - cell_spec = JSONField(index=True, null=True) + model_spec = JSONField(json_dumps=json_dumps, index=True) + cell_spec = JSONField(json_dumps=json_dumps, index=True, null=True) dataset = CharField(max_length=15, index=True, choices=['cifar10', 'imagenet']) generator = CharField(max_length=15, index=True, choices=[ 'random', diff --git a/src/sdk/pynni/nni/nas/benchmarks/nds/query.py b/src/sdk/pynni/nni/nas/benchmarks/nds/query.py index 618b9e57b3..fe192ba509 100644 --- a/src/sdk/pynni/nni/nas/benchmarks/nds/query.py +++ b/src/sdk/pynni/nni/nas/benchmarks/nds/query.py @@ -6,7 +6,7 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_spec, dataset, - num_epochs=None, reduction=None): + num_epochs=None, reduction=None, include_intermediates=False): """ Query trial stats of NDS given conditions. @@ -32,6 +32,8 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_sp reduction : str or None If 'none' or None, all trial stats will be returned directly. If 'mean', fields in trial stats will be averaged given the same trial config. + include_intermediates : boolean + If true, intermediate results will be returned. Returns ------- @@ -60,5 +62,13 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_sp query = query.where(functools.reduce(lambda a, b: a & b, conditions)) if reduction is not None: query = query.group_by(NdsTrialStats.config) - for k in query: - yield model_to_dict(k) + for trial in query: + if include_intermediates: + data = model_to_dict(trial) + # exclude 'trial' from intermediates as it is already available in data + data['intermediates'] = [ + {k: v for k, v in model_to_dict(t).items() if k != 'trial'} for t in trial.intermediates + ] + yield data + else: + yield model_to_dict(trial) diff --git a/src/sdk/pynni/nni/nas/benchmarks/utils.py b/src/sdk/pynni/nni/nas/benchmarks/utils.py new file mode 100644 index 0000000000..7189d52f5d --- /dev/null +++ b/src/sdk/pynni/nni/nas/benchmarks/utils.py @@ -0,0 +1,5 @@ +import functools +import json + + +json_dumps = functools.partial(json.dumps, sort_keys=True) From 80865fddbf7fb59aed6db0a14e4802a5906f2453 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Sun, 26 Jul 2020 15:42:23 +0800 Subject: [PATCH 2/2] Finalize example --- docs/en_US/NAS/BenchmarksExample.ipynb | 107 ++++++++++++++---- .../nni/nas/benchmarks/nasbench201/query.py | 15 --- 2 files changed, 88 insertions(+), 34 deletions(-) diff --git a/docs/en_US/NAS/BenchmarksExample.ipynb b/docs/en_US/NAS/BenchmarksExample.ipynb index 376da76f05..dd4f24a54b 100644 --- a/docs/en_US/NAS/BenchmarksExample.ipynb +++ b/docs/en_US/NAS/BenchmarksExample.ipynb @@ -34,19 +34,22 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Use the following architecture as an example:
\n", + "Use the following architecture as an example:\n", + "\n", "![nas-101](../../img/nas-bench-101-example.png)" ] }, { "cell_type": "code", "execution_count": 2, - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "output_type": "stream", "name": "stdout", - "text": "{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 10,\n 'parameters': 8.55553,\n 'test_acc': 92.11738705635071,\n 'train_acc': 100.0,\n 'training_time': 106147.67578125,\n 'valid_acc': 92.41786599159241}\n{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 11,\n 'parameters': 8.55553,\n 'test_acc': 91.90705418586731,\n 'train_acc': 100.0,\n 'training_time': 106095.05859375,\n 'valid_acc': 92.45793223381042}\n{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 12,\n 'parameters': 8.55553,\n 'test_acc': 92.15745329856873,\n 'train_acc': 100.0,\n 'training_time': 106138.55712890625,\n 'valid_acc': 93.04887652397156}\n" + "text": "{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 10,\n 'intermediates': [{'current_epoch': 54,\n 'id': 19,\n 'test_acc': 77.40384340286255,\n 'train_acc': 82.82251358032227,\n 'training_time': 883.4580078125,\n 'valid_acc': 77.76442170143127},\n {'current_epoch': 108,\n 'id': 20,\n 'test_acc': 92.11738705635071,\n 'train_acc': 100.0,\n 'training_time': 1769.1279296875,\n 'valid_acc': 92.41786599159241}],\n 'parameters': 8.55553,\n 'test_acc': 92.11738705635071,\n 'train_acc': 100.0,\n 'training_time': 106147.67578125,\n 'valid_acc': 92.41786599159241}\n{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 11,\n 'intermediates': [{'current_epoch': 54,\n 'id': 21,\n 'test_acc': 82.04126358032227,\n 'train_acc': 87.96073794364929,\n 'training_time': 883.6810302734375,\n 'valid_acc': 82.91265964508057},\n {'current_epoch': 108,\n 'id': 22,\n 'test_acc': 91.90705418586731,\n 'train_acc': 100.0,\n 'training_time': 1768.2509765625,\n 'valid_acc': 92.45793223381042}],\n 'parameters': 8.55553,\n 'test_acc': 91.90705418586731,\n 'train_acc': 100.0,\n 'training_time': 106095.05859375,\n 'valid_acc': 92.45793223381042}\n{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 12,\n 'intermediates': [{'current_epoch': 54,\n 'id': 23,\n 'test_acc': 80.58894276618958,\n 'train_acc': 86.34815812110901,\n 'training_time': 883.4569702148438,\n 'valid_acc': 81.1598539352417},\n {'current_epoch': 108,\n 'id': 24,\n 'test_acc': 92.15745329856873,\n 'train_acc': 100.0,\n 'training_time': 1768.9759521484375,\n 'valid_acc': 93.04887652397156}],\n 'parameters': 8.55553,\n 'test_acc': 92.15745329856873,\n 'train_acc': 100.0,\n 'training_time': 106138.55712890625,\n 'valid_acc': 93.04887652397156}\n" } ], "source": [ @@ -63,7 +66,7 @@ " 'input5': [0, 3, 4],\n", " 'input6': [2, 5]\n", "}\n", - "for t in query_nb101_trial_stats(arch, 108):\n", + "for t in query_nb101_trial_stats(arch, 108, include_intermediates=True):\n", " pprint.pprint(t)" ] }, @@ -85,14 +88,17 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Use the following architecture as an example:
\n", + "Use the following architecture as an example:\n", + "\n", "![nas-201](../../img/nas-bench-201-example.png)" ] }, { "cell_type": "code", "execution_count": 3, - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "output_type": "stream", @@ -113,6 +119,32 @@ " pprint.pprint(t)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Intermediate results are also available." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": "{'id': 4, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 12, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 12\n{'id': 8, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 200, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 200\n{'id': 8, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 200, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 200\n{'id': 8, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 200, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 200\n" + } + ], + "source": [ + "for t in query_nb201_trial_stats(arch, None, 'imagenet16-120', include_intermediates=True):\n", + " print(t['config'])\n", + " print('Intermediates:', len(t['intermediates']))" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -132,8 +164,10 @@ }, { "cell_type": "code", - "execution_count": 4, - "metadata": {}, + "execution_count": 5, + "metadata": { + "tags": [] + }, "outputs": [ { "output_type": "stream", @@ -156,8 +190,35 @@ }, { "cell_type": "code", - "execution_count": 5, - "metadata": {}, + "execution_count": 6, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": "[{'current_epoch': 1,\n 'id': 4494501,\n 'test_acc': 41.76,\n 'train_acc': 30.421000000000006,\n 'train_loss': 1.793},\n {'current_epoch': 2,\n 'id': 4494502,\n 'test_acc': 54.66,\n 'train_acc': 47.24,\n 'train_loss': 1.415},\n {'current_epoch': 3,\n 'id': 4494503,\n 'test_acc': 59.97,\n 'train_acc': 56.983,\n 'train_loss': 1.179},\n {'current_epoch': 4,\n 'id': 4494504,\n 'test_acc': 62.91,\n 'train_acc': 61.955,\n 'train_loss': 1.048},\n {'current_epoch': 5,\n 'id': 4494505,\n 'test_acc': 66.16,\n 'train_acc': 64.493,\n 'train_loss': 0.983},\n {'current_epoch': 6,\n 'id': 4494506,\n 'test_acc': 66.5,\n 'train_acc': 66.274,\n 'train_loss': 0.937},\n {'current_epoch': 7,\n 'id': 4494507,\n 'test_acc': 67.55,\n 'train_acc': 67.426,\n 'train_loss': 0.907},\n {'current_epoch': 8,\n 'id': 4494508,\n 'test_acc': 69.45,\n 'train_acc': 68.45400000000001,\n 'train_loss': 0.878},\n {'current_epoch': 9,\n 'id': 4494509,\n 'test_acc': 70.14,\n 'train_acc': 69.295,\n 'train_loss': 0.857},\n {'current_epoch': 10,\n 'id': 4494510,\n 'test_acc': 69.47,\n 'train_acc': 70.304,\n 'train_loss': 0.832}]\n" + } + ], + "source": [ + "model_spec = {\n", + " 'bot_muls': [0.0, 0.25, 0.25, 0.25],\n", + " 'ds': [1, 16, 1, 4],\n", + " 'num_gs': [1, 2, 1, 2],\n", + " 'ss': [1, 1, 2, 2],\n", + " 'ws': [16, 64, 128, 16]\n", + "}\n", + "for t in query_nds_trial_stats('residual_bottleneck', None, None, model_spec, None, 'cifar10', include_intermediates=True):\n", + " pprint.pprint(t['intermediates'][:10])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "tags": [] + }, "outputs": [ { "output_type": "stream", @@ -173,8 +234,10 @@ }, { "cell_type": "code", - "execution_count": 6, - "metadata": {}, + "execution_count": 8, + "metadata": { + "tags": [] + }, "outputs": [ { "output_type": "stream", @@ -189,8 +252,10 @@ }, { "cell_type": "code", - "execution_count": 7, - "metadata": {}, + "execution_count": 9, + "metadata": { + "tags": [] + }, "outputs": [ { "output_type": "stream", @@ -254,8 +319,10 @@ }, { "cell_type": "code", - "execution_count": 8, - "metadata": {}, + "execution_count": 10, + "metadata": { + "tags": [] + }, "outputs": [ { "output_type": "stream", @@ -270,13 +337,15 @@ }, { "cell_type": "code", - "execution_count": 9, - "metadata": {}, + "execution_count": 11, + "metadata": { + "tags": [] + }, "outputs": [ { "output_type": "stream", "name": "stdout", - "text": "Elapsed time: 1.9107539653778076 seconds\n" + "text": "Elapsed time: 2.2023813724517822 seconds\n" } ], "source": [ diff --git a/src/sdk/pynni/nni/nas/benchmarks/nasbench201/query.py b/src/sdk/pynni/nni/nas/benchmarks/nasbench201/query.py index 6b51725d6d..3272590614 100644 --- a/src/sdk/pynni/nni/nas/benchmarks/nasbench201/query.py +++ b/src/sdk/pynni/nni/nas/benchmarks/nasbench201/query.py @@ -65,18 +65,3 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_i yield data else: yield model_to_dict(trial) - - -if __name__ == "__main__": - import pprint - arch = { - '0_1': 'avg_pool_3x3', - '0_2': 'conv_1x1', - '0_3': 'conv_1x1', - '1_2': 'skip_connect', - '1_3': 'skip_connect', - '2_3': 'skip_connect' - } - for t in query_nb201_trial_stats(arch, 200, 'cifar100', include_intermediates=False): - pprint.pprint(t) - break