diff --git a/docs/en_US/NAS/BenchmarksExample.ipynb b/docs/en_US/NAS/BenchmarksExample.ipynb
index 376da76f05..dd4f24a54b 100644
--- a/docs/en_US/NAS/BenchmarksExample.ipynb
+++ b/docs/en_US/NAS/BenchmarksExample.ipynb
@@ -34,19 +34,22 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Use the following architecture as an example:
\n",
+ "Use the following architecture as an example:\n",
+ "\n",
"![nas-101](../../img/nas-bench-101-example.png)"
]
},
{
"cell_type": "code",
"execution_count": 2,
- "metadata": {},
+ "metadata": {
+ "tags": []
+ },
"outputs": [
{
"output_type": "stream",
"name": "stdout",
- "text": "{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 10,\n 'parameters': 8.55553,\n 'test_acc': 92.11738705635071,\n 'train_acc': 100.0,\n 'training_time': 106147.67578125,\n 'valid_acc': 92.41786599159241}\n{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 11,\n 'parameters': 8.55553,\n 'test_acc': 91.90705418586731,\n 'train_acc': 100.0,\n 'training_time': 106095.05859375,\n 'valid_acc': 92.45793223381042}\n{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 12,\n 'parameters': 8.55553,\n 'test_acc': 92.15745329856873,\n 'train_acc': 100.0,\n 'training_time': 106138.55712890625,\n 'valid_acc': 93.04887652397156}\n"
+ "text": "{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 10,\n 'intermediates': [{'current_epoch': 54,\n 'id': 19,\n 'test_acc': 77.40384340286255,\n 'train_acc': 82.82251358032227,\n 'training_time': 883.4580078125,\n 'valid_acc': 77.76442170143127},\n {'current_epoch': 108,\n 'id': 20,\n 'test_acc': 92.11738705635071,\n 'train_acc': 100.0,\n 'training_time': 1769.1279296875,\n 'valid_acc': 92.41786599159241}],\n 'parameters': 8.55553,\n 'test_acc': 92.11738705635071,\n 'train_acc': 100.0,\n 'training_time': 106147.67578125,\n 'valid_acc': 92.41786599159241}\n{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 11,\n 'intermediates': [{'current_epoch': 54,\n 'id': 21,\n 'test_acc': 82.04126358032227,\n 'train_acc': 87.96073794364929,\n 'training_time': 883.6810302734375,\n 'valid_acc': 82.91265964508057},\n {'current_epoch': 108,\n 'id': 22,\n 'test_acc': 91.90705418586731,\n 'train_acc': 100.0,\n 'training_time': 1768.2509765625,\n 'valid_acc': 92.45793223381042}],\n 'parameters': 8.55553,\n 'test_acc': 91.90705418586731,\n 'train_acc': 100.0,\n 'training_time': 106095.05859375,\n 'valid_acc': 92.45793223381042}\n{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 12,\n 'intermediates': [{'current_epoch': 54,\n 'id': 23,\n 'test_acc': 80.58894276618958,\n 'train_acc': 86.34815812110901,\n 'training_time': 883.4569702148438,\n 'valid_acc': 81.1598539352417},\n {'current_epoch': 108,\n 'id': 24,\n 'test_acc': 92.15745329856873,\n 'train_acc': 100.0,\n 'training_time': 1768.9759521484375,\n 'valid_acc': 93.04887652397156}],\n 'parameters': 8.55553,\n 'test_acc': 92.15745329856873,\n 'train_acc': 100.0,\n 'training_time': 106138.55712890625,\n 'valid_acc': 93.04887652397156}\n"
}
],
"source": [
@@ -63,7 +66,7 @@
" 'input5': [0, 3, 4],\n",
" 'input6': [2, 5]\n",
"}\n",
- "for t in query_nb101_trial_stats(arch, 108):\n",
+ "for t in query_nb101_trial_stats(arch, 108, include_intermediates=True):\n",
" pprint.pprint(t)"
]
},
@@ -85,14 +88,17 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Use the following architecture as an example:
\n",
+ "Use the following architecture as an example:\n",
+ "\n",
"![nas-201](../../img/nas-bench-201-example.png)"
]
},
{
"cell_type": "code",
"execution_count": 3,
- "metadata": {},
+ "metadata": {
+ "tags": []
+ },
"outputs": [
{
"output_type": "stream",
@@ -113,6 +119,32 @@
" pprint.pprint(t)"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Intermediate results are also available."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": "{'id': 4, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 12, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 12\n{'id': 8, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 200, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 200\n{'id': 8, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 200, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 200\n{'id': 8, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 200, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 200\n"
+ }
+ ],
+ "source": [
+ "for t in query_nb201_trial_stats(arch, None, 'imagenet16-120', include_intermediates=True):\n",
+ " print(t['config'])\n",
+ " print('Intermediates:', len(t['intermediates']))"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
@@ -132,8 +164,10 @@
},
{
"cell_type": "code",
- "execution_count": 4,
- "metadata": {},
+ "execution_count": 5,
+ "metadata": {
+ "tags": []
+ },
"outputs": [
{
"output_type": "stream",
@@ -156,8 +190,35 @@
},
{
"cell_type": "code",
- "execution_count": 5,
- "metadata": {},
+ "execution_count": 6,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": "[{'current_epoch': 1,\n 'id': 4494501,\n 'test_acc': 41.76,\n 'train_acc': 30.421000000000006,\n 'train_loss': 1.793},\n {'current_epoch': 2,\n 'id': 4494502,\n 'test_acc': 54.66,\n 'train_acc': 47.24,\n 'train_loss': 1.415},\n {'current_epoch': 3,\n 'id': 4494503,\n 'test_acc': 59.97,\n 'train_acc': 56.983,\n 'train_loss': 1.179},\n {'current_epoch': 4,\n 'id': 4494504,\n 'test_acc': 62.91,\n 'train_acc': 61.955,\n 'train_loss': 1.048},\n {'current_epoch': 5,\n 'id': 4494505,\n 'test_acc': 66.16,\n 'train_acc': 64.493,\n 'train_loss': 0.983},\n {'current_epoch': 6,\n 'id': 4494506,\n 'test_acc': 66.5,\n 'train_acc': 66.274,\n 'train_loss': 0.937},\n {'current_epoch': 7,\n 'id': 4494507,\n 'test_acc': 67.55,\n 'train_acc': 67.426,\n 'train_loss': 0.907},\n {'current_epoch': 8,\n 'id': 4494508,\n 'test_acc': 69.45,\n 'train_acc': 68.45400000000001,\n 'train_loss': 0.878},\n {'current_epoch': 9,\n 'id': 4494509,\n 'test_acc': 70.14,\n 'train_acc': 69.295,\n 'train_loss': 0.857},\n {'current_epoch': 10,\n 'id': 4494510,\n 'test_acc': 69.47,\n 'train_acc': 70.304,\n 'train_loss': 0.832}]\n"
+ }
+ ],
+ "source": [
+ "model_spec = {\n",
+ " 'bot_muls': [0.0, 0.25, 0.25, 0.25],\n",
+ " 'ds': [1, 16, 1, 4],\n",
+ " 'num_gs': [1, 2, 1, 2],\n",
+ " 'ss': [1, 1, 2, 2],\n",
+ " 'ws': [16, 64, 128, 16]\n",
+ "}\n",
+ "for t in query_nds_trial_stats('residual_bottleneck', None, None, model_spec, None, 'cifar10', include_intermediates=True):\n",
+ " pprint.pprint(t['intermediates'][:10])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "tags": []
+ },
"outputs": [
{
"output_type": "stream",
@@ -173,8 +234,10 @@
},
{
"cell_type": "code",
- "execution_count": 6,
- "metadata": {},
+ "execution_count": 8,
+ "metadata": {
+ "tags": []
+ },
"outputs": [
{
"output_type": "stream",
@@ -189,8 +252,10 @@
},
{
"cell_type": "code",
- "execution_count": 7,
- "metadata": {},
+ "execution_count": 9,
+ "metadata": {
+ "tags": []
+ },
"outputs": [
{
"output_type": "stream",
@@ -254,8 +319,10 @@
},
{
"cell_type": "code",
- "execution_count": 8,
- "metadata": {},
+ "execution_count": 10,
+ "metadata": {
+ "tags": []
+ },
"outputs": [
{
"output_type": "stream",
@@ -270,13 +337,15 @@
},
{
"cell_type": "code",
- "execution_count": 9,
- "metadata": {},
+ "execution_count": 11,
+ "metadata": {
+ "tags": []
+ },
"outputs": [
{
"output_type": "stream",
"name": "stdout",
- "text": "Elapsed time: 1.9107539653778076 seconds\n"
+ "text": "Elapsed time: 2.2023813724517822 seconds\n"
}
],
"source": [
diff --git a/src/sdk/pynni/nni/nas/benchmarks/nasbench101/model.py b/src/sdk/pynni/nni/nas/benchmarks/nasbench101/model.py
index 00a2596a5f..44ec3f874f 100644
--- a/src/sdk/pynni/nni/nas/benchmarks/nasbench101/model.py
+++ b/src/sdk/pynni/nni/nas/benchmarks/nasbench101/model.py
@@ -4,6 +4,7 @@
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase
from nni.nas.benchmarks.constants import DATABASE_DIR
+from nni.nas.benchmarks.utils import json_dumps
db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nasbench101.db'), autoconnect=True)
@@ -28,7 +29,7 @@ class Nb101TrialConfig(Model):
Number of epochs planned for this trial. Should be one of 4, 12, 36, 108 in default setup.
"""
- arch = JSONField(index=True)
+ arch = JSONField(json_dumps=json_dumps, index=True)
num_vertices = IntegerField(index=True)
hash = CharField(max_length=64, index=True)
num_epochs = IntegerField(index=True)
diff --git a/src/sdk/pynni/nni/nas/benchmarks/nasbench101/query.py b/src/sdk/pynni/nni/nas/benchmarks/nasbench101/query.py
index 28c0dc03be..1c54c24f46 100644
--- a/src/sdk/pynni/nni/nas/benchmarks/nasbench101/query.py
+++ b/src/sdk/pynni/nni/nas/benchmarks/nasbench101/query.py
@@ -6,7 +6,7 @@
from .graph_util import hash_module, infer_num_vertices
-def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None):
+def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None, include_intermediates=False):
"""
Query trial stats of NAS-Bench-101 given conditions.
@@ -24,6 +24,8 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None):
reduction : str or None
If 'none' or None, all trial stats will be returned directly.
If 'mean', fields in trial stats will be averaged given the same trial config.
+ include_intermediates : boolean
+ If true, intermediate results will be returned.
Returns
-------
@@ -56,5 +58,13 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None):
query = query.where(functools.reduce(lambda a, b: a & b, conditions))
if reduction is not None:
query = query.group_by(Nb101TrialStats.config)
- for k in query:
- yield model_to_dict(k)
+ for trial in query:
+ if include_intermediates:
+ data = model_to_dict(trial)
+ # exclude 'trial' from intermediates as it is already available in data
+ data['intermediates'] = [
+ {k: v for k, v in model_to_dict(t).items() if k != 'trial'} for t in trial.intermediates
+ ]
+ yield data
+ else:
+ yield model_to_dict(trial)
diff --git a/src/sdk/pynni/nni/nas/benchmarks/nasbench201/model.py b/src/sdk/pynni/nni/nas/benchmarks/nasbench201/model.py
index 3e3322f921..3b898de7c8 100644
--- a/src/sdk/pynni/nni/nas/benchmarks/nasbench201/model.py
+++ b/src/sdk/pynni/nni/nas/benchmarks/nasbench201/model.py
@@ -4,6 +4,7 @@
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase
from nni.nas.benchmarks.constants import DATABASE_DIR
+from nni.nas.benchmarks.utils import json_dumps
db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nasbench201.db'), autoconnect=True)
@@ -35,7 +36,7 @@ class Nb201TrialConfig(Model):
for training, 6k images from validation set for validation and the other 6k for testing).
"""
- arch = JSONField(index=True)
+ arch = JSONField(json_dumps=json_dumps, index=True)
num_epochs = IntegerField(index=True)
num_channels = IntegerField()
num_cells = IntegerField()
diff --git a/src/sdk/pynni/nni/nas/benchmarks/nasbench201/query.py b/src/sdk/pynni/nni/nas/benchmarks/nasbench201/query.py
index 3e5a29dea0..3272590614 100644
--- a/src/sdk/pynni/nni/nas/benchmarks/nasbench201/query.py
+++ b/src/sdk/pynni/nni/nas/benchmarks/nasbench201/query.py
@@ -5,7 +5,7 @@
from .model import Nb201TrialStats, Nb201TrialConfig
-def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None):
+def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_intermediates=False):
"""
Query trial stats of NAS-Bench-201 given conditions.
@@ -23,6 +23,8 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None):
reduction : str or None
If 'none' or None, all trial stats will be returned directly.
If 'mean', fields in trial stats will be averaged given the same trial config.
+ include_intermediates : boolean
+ If true, intermediate results will be returned.
Returns
-------
@@ -53,5 +55,13 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None):
query = query.where(functools.reduce(lambda a, b: a & b, conditions))
if reduction is not None:
query = query.group_by(Nb201TrialStats.config)
- for k in query:
- yield model_to_dict(k)
+ for trial in query:
+ if include_intermediates:
+ data = model_to_dict(trial)
+ # exclude 'trial' from intermediates as it is already available in data
+ data['intermediates'] = [
+ {k: v for k, v in model_to_dict(t).items() if k != 'trial'} for t in trial.intermediates
+ ]
+ yield data
+ else:
+ yield model_to_dict(trial)
diff --git a/src/sdk/pynni/nni/nas/benchmarks/nds/model.py b/src/sdk/pynni/nni/nas/benchmarks/nds/model.py
index d7f6894da1..a6ace351d7 100644
--- a/src/sdk/pynni/nni/nas/benchmarks/nds/model.py
+++ b/src/sdk/pynni/nni/nas/benchmarks/nds/model.py
@@ -4,6 +4,7 @@
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase
from nni.nas.benchmarks.constants import DATABASE_DIR
+from nni.nas.benchmarks.utils import json_dumps
db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nds.db'), autoconnect=True)
@@ -52,8 +53,8 @@ class NdsTrialConfig(Model):
'residual_basic',
'vanilla',
])
- model_spec = JSONField(index=True)
- cell_spec = JSONField(index=True, null=True)
+ model_spec = JSONField(json_dumps=json_dumps, index=True)
+ cell_spec = JSONField(json_dumps=json_dumps, index=True, null=True)
dataset = CharField(max_length=15, index=True, choices=['cifar10', 'imagenet'])
generator = CharField(max_length=15, index=True, choices=[
'random',
diff --git a/src/sdk/pynni/nni/nas/benchmarks/nds/query.py b/src/sdk/pynni/nni/nas/benchmarks/nds/query.py
index 618b9e57b3..fe192ba509 100644
--- a/src/sdk/pynni/nni/nas/benchmarks/nds/query.py
+++ b/src/sdk/pynni/nni/nas/benchmarks/nds/query.py
@@ -6,7 +6,7 @@
def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_spec, dataset,
- num_epochs=None, reduction=None):
+ num_epochs=None, reduction=None, include_intermediates=False):
"""
Query trial stats of NDS given conditions.
@@ -32,6 +32,8 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_sp
reduction : str or None
If 'none' or None, all trial stats will be returned directly.
If 'mean', fields in trial stats will be averaged given the same trial config.
+ include_intermediates : boolean
+ If true, intermediate results will be returned.
Returns
-------
@@ -60,5 +62,13 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_sp
query = query.where(functools.reduce(lambda a, b: a & b, conditions))
if reduction is not None:
query = query.group_by(NdsTrialStats.config)
- for k in query:
- yield model_to_dict(k)
+ for trial in query:
+ if include_intermediates:
+ data = model_to_dict(trial)
+ # exclude 'trial' from intermediates as it is already available in data
+ data['intermediates'] = [
+ {k: v for k, v in model_to_dict(t).items() if k != 'trial'} for t in trial.intermediates
+ ]
+ yield data
+ else:
+ yield model_to_dict(trial)
diff --git a/src/sdk/pynni/nni/nas/benchmarks/utils.py b/src/sdk/pynni/nni/nas/benchmarks/utils.py
new file mode 100644
index 0000000000..7189d52f5d
--- /dev/null
+++ b/src/sdk/pynni/nni/nas/benchmarks/utils.py
@@ -0,0 +1,5 @@
+import functools
+import json
+
+
+json_dumps = functools.partial(json.dumps, sort_keys=True)