Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Add API to query intermediate results in NAS benchmark #2728

Merged
merged 2 commits into from
Jul 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 88 additions & 19 deletions docs/en_US/NAS/BenchmarksExample.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,22 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Use the following architecture as an example:<br>\n",
"Use the following architecture as an example:\n",
"\n",
"![nas-101](../../img/nas-bench-101-example.png)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 10,\n 'parameters': 8.55553,\n 'test_acc': 92.11738705635071,\n 'train_acc': 100.0,\n 'training_time': 106147.67578125,\n 'valid_acc': 92.41786599159241}\n{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 11,\n 'parameters': 8.55553,\n 'test_acc': 91.90705418586731,\n 'train_acc': 100.0,\n 'training_time': 106095.05859375,\n 'valid_acc': 92.45793223381042}\n{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 12,\n 'parameters': 8.55553,\n 'test_acc': 92.15745329856873,\n 'train_acc': 100.0,\n 'training_time': 106138.55712890625,\n 'valid_acc': 93.04887652397156}\n"
"text": "{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 10,\n 'intermediates': [{'current_epoch': 54,\n 'id': 19,\n 'test_acc': 77.40384340286255,\n 'train_acc': 82.82251358032227,\n 'training_time': 883.4580078125,\n 'valid_acc': 77.76442170143127},\n {'current_epoch': 108,\n 'id': 20,\n 'test_acc': 92.11738705635071,\n 'train_acc': 100.0,\n 'training_time': 1769.1279296875,\n 'valid_acc': 92.41786599159241}],\n 'parameters': 8.55553,\n 'test_acc': 92.11738705635071,\n 'train_acc': 100.0,\n 'training_time': 106147.67578125,\n 'valid_acc': 92.41786599159241}\n{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 11,\n 'intermediates': [{'current_epoch': 54,\n 'id': 21,\n 'test_acc': 82.04126358032227,\n 'train_acc': 87.96073794364929,\n 'training_time': 883.6810302734375,\n 'valid_acc': 82.91265964508057},\n {'current_epoch': 108,\n 'id': 22,\n 'test_acc': 91.90705418586731,\n 'train_acc': 100.0,\n 'training_time': 1768.2509765625,\n 'valid_acc': 92.45793223381042}],\n 'parameters': 8.55553,\n 'test_acc': 91.90705418586731,\n 'train_acc': 100.0,\n 'training_time': 106095.05859375,\n 'valid_acc': 92.45793223381042}\n{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 12,\n 'intermediates': [{'current_epoch': 54,\n 'id': 23,\n 'test_acc': 80.58894276618958,\n 'train_acc': 86.34815812110901,\n 'training_time': 883.4569702148438,\n 'valid_acc': 81.1598539352417},\n {'current_epoch': 108,\n 'id': 24,\n 'test_acc': 92.15745329856873,\n 'train_acc': 100.0,\n 'training_time': 1768.9759521484375,\n 'valid_acc': 93.04887652397156}],\n 'parameters': 8.55553,\n 'test_acc': 92.15745329856873,\n 'train_acc': 100.0,\n 'training_time': 106138.55712890625,\n 'valid_acc': 93.04887652397156}\n"
}
],
"source": [
Expand All @@ -63,7 +66,7 @@
" 'input5': [0, 3, 4],\n",
" 'input6': [2, 5]\n",
"}\n",
"for t in query_nb101_trial_stats(arch, 108):\n",
"for t in query_nb101_trial_stats(arch, 108, include_intermediates=True):\n",
" pprint.pprint(t)"
]
},
Expand All @@ -85,14 +88,17 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Use the following architecture as an example:<br>\n",
"Use the following architecture as an example:\n",
"\n",
"![nas-201](../../img/nas-bench-201-example.png)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
Expand All @@ -113,6 +119,32 @@
" pprint.pprint(t)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Intermediate results are also available."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "{'id': 4, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 12, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 12\n{'id': 8, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 200, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 200\n{'id': 8, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 200, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 200\n{'id': 8, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 200, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 200\n"
}
],
"source": [
"for t in query_nb201_trial_stats(arch, None, 'imagenet16-120', include_intermediates=True):\n",
" print(t['config'])\n",
" print('Intermediates:', len(t['intermediates']))"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -132,8 +164,10 @@
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"execution_count": 5,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
Expand All @@ -156,8 +190,35 @@
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"execution_count": 6,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "[{'current_epoch': 1,\n 'id': 4494501,\n 'test_acc': 41.76,\n 'train_acc': 30.421000000000006,\n 'train_loss': 1.793},\n {'current_epoch': 2,\n 'id': 4494502,\n 'test_acc': 54.66,\n 'train_acc': 47.24,\n 'train_loss': 1.415},\n {'current_epoch': 3,\n 'id': 4494503,\n 'test_acc': 59.97,\n 'train_acc': 56.983,\n 'train_loss': 1.179},\n {'current_epoch': 4,\n 'id': 4494504,\n 'test_acc': 62.91,\n 'train_acc': 61.955,\n 'train_loss': 1.048},\n {'current_epoch': 5,\n 'id': 4494505,\n 'test_acc': 66.16,\n 'train_acc': 64.493,\n 'train_loss': 0.983},\n {'current_epoch': 6,\n 'id': 4494506,\n 'test_acc': 66.5,\n 'train_acc': 66.274,\n 'train_loss': 0.937},\n {'current_epoch': 7,\n 'id': 4494507,\n 'test_acc': 67.55,\n 'train_acc': 67.426,\n 'train_loss': 0.907},\n {'current_epoch': 8,\n 'id': 4494508,\n 'test_acc': 69.45,\n 'train_acc': 68.45400000000001,\n 'train_loss': 0.878},\n {'current_epoch': 9,\n 'id': 4494509,\n 'test_acc': 70.14,\n 'train_acc': 69.295,\n 'train_loss': 0.857},\n {'current_epoch': 10,\n 'id': 4494510,\n 'test_acc': 69.47,\n 'train_acc': 70.304,\n 'train_loss': 0.832}]\n"
}
],
"source": [
"model_spec = {\n",
" 'bot_muls': [0.0, 0.25, 0.25, 0.25],\n",
" 'ds': [1, 16, 1, 4],\n",
" 'num_gs': [1, 2, 1, 2],\n",
" 'ss': [1, 1, 2, 2],\n",
" 'ws': [16, 64, 128, 16]\n",
"}\n",
"for t in query_nds_trial_stats('residual_bottleneck', None, None, model_spec, None, 'cifar10', include_intermediates=True):\n",
" pprint.pprint(t['intermediates'][:10])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
Expand All @@ -173,8 +234,10 @@
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"execution_count": 8,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
Expand All @@ -189,8 +252,10 @@
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"execution_count": 9,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
Expand Down Expand Up @@ -254,8 +319,10 @@
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"execution_count": 10,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
Expand All @@ -270,13 +337,15 @@
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"execution_count": 11,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "Elapsed time: 1.9107539653778076 seconds\n"
"text": "Elapsed time: 2.2023813724517822 seconds\n"
}
],
"source": [
Expand Down
3 changes: 2 additions & 1 deletion src/sdk/pynni/nni/nas/benchmarks/nasbench101/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase

from nni.nas.benchmarks.constants import DATABASE_DIR
from nni.nas.benchmarks.utils import json_dumps

db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nasbench101.db'), autoconnect=True)

Expand All @@ -28,7 +29,7 @@ class Nb101TrialConfig(Model):
Number of epochs planned for this trial. Should be one of 4, 12, 36, 108 in default setup.
"""

arch = JSONField(index=True)
arch = JSONField(json_dumps=json_dumps, index=True)
num_vertices = IntegerField(index=True)
hash = CharField(max_length=64, index=True)
num_epochs = IntegerField(index=True)
Expand Down
16 changes: 13 additions & 3 deletions src/sdk/pynni/nni/nas/benchmarks/nasbench101/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .graph_util import hash_module, infer_num_vertices


def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None):
def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None, include_intermediates=False):
"""
Query trial stats of NAS-Bench-101 given conditions.

Expand All @@ -24,6 +24,8 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None):
reduction : str or None
If 'none' or None, all trial stats will be returned directly.
If 'mean', fields in trial stats will be averaged given the same trial config.
include_intermediates : boolean
If true, intermediate results will be returned.

Returns
-------
Expand Down Expand Up @@ -56,5 +58,13 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None):
query = query.where(functools.reduce(lambda a, b: a & b, conditions))
if reduction is not None:
query = query.group_by(Nb101TrialStats.config)
for k in query:
yield model_to_dict(k)
for trial in query:
if include_intermediates:
data = model_to_dict(trial)
# exclude 'trial' from intermediates as it is already available in data
data['intermediates'] = [
{k: v for k, v in model_to_dict(t).items() if k != 'trial'} for t in trial.intermediates
]
yield data
else:
yield model_to_dict(trial)
3 changes: 2 additions & 1 deletion src/sdk/pynni/nni/nas/benchmarks/nasbench201/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase

from nni.nas.benchmarks.constants import DATABASE_DIR
from nni.nas.benchmarks.utils import json_dumps

db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nasbench201.db'), autoconnect=True)

Expand Down Expand Up @@ -35,7 +36,7 @@ class Nb201TrialConfig(Model):
for training, 6k images from validation set for validation and the other 6k for testing).
"""

arch = JSONField(index=True)
arch = JSONField(json_dumps=json_dumps, index=True)
num_epochs = IntegerField(index=True)
num_channels = IntegerField()
num_cells = IntegerField()
Expand Down
16 changes: 13 additions & 3 deletions src/sdk/pynni/nni/nas/benchmarks/nasbench201/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .model import Nb201TrialStats, Nb201TrialConfig


def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None):
def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_intermediates=False):
"""
Query trial stats of NAS-Bench-201 given conditions.

Expand All @@ -23,6 +23,8 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None):
reduction : str or None
If 'none' or None, all trial stats will be returned directly.
If 'mean', fields in trial stats will be averaged given the same trial config.
include_intermediates : boolean
If true, intermediate results will be returned.

Returns
-------
Expand Down Expand Up @@ -53,5 +55,13 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None):
query = query.where(functools.reduce(lambda a, b: a & b, conditions))
if reduction is not None:
query = query.group_by(Nb201TrialStats.config)
for k in query:
yield model_to_dict(k)
for trial in query:
if include_intermediates:
data = model_to_dict(trial)
# exclude 'trial' from intermediates as it is already available in data
data['intermediates'] = [
{k: v for k, v in model_to_dict(t).items() if k != 'trial'} for t in trial.intermediates
]
yield data
else:
yield model_to_dict(trial)
5 changes: 3 additions & 2 deletions src/sdk/pynni/nni/nas/benchmarks/nds/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase

from nni.nas.benchmarks.constants import DATABASE_DIR
from nni.nas.benchmarks.utils import json_dumps

db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nds.db'), autoconnect=True)

Expand Down Expand Up @@ -52,8 +53,8 @@ class NdsTrialConfig(Model):
'residual_basic',
'vanilla',
])
model_spec = JSONField(index=True)
cell_spec = JSONField(index=True, null=True)
model_spec = JSONField(json_dumps=json_dumps, index=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this modification can resolve the dict element sequence issue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comparison of JSON field is based on string comparison, and to compare two dumped JSON of unordered dict, they need to be sorted ahead.

cell_spec = JSONField(json_dumps=json_dumps, index=True, null=True)
dataset = CharField(max_length=15, index=True, choices=['cifar10', 'imagenet'])
generator = CharField(max_length=15, index=True, choices=[
'random',
Expand Down
16 changes: 13 additions & 3 deletions src/sdk/pynni/nni/nas/benchmarks/nds/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_spec, dataset,
num_epochs=None, reduction=None):
num_epochs=None, reduction=None, include_intermediates=False):
"""
Query trial stats of NDS given conditions.

Expand All @@ -32,6 +32,8 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_sp
reduction : str or None
If 'none' or None, all trial stats will be returned directly.
If 'mean', fields in trial stats will be averaged given the same trial config.
include_intermediates : boolean
If true, intermediate results will be returned.

Returns
-------
Expand Down Expand Up @@ -60,5 +62,13 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_sp
query = query.where(functools.reduce(lambda a, b: a & b, conditions))
if reduction is not None:
query = query.group_by(NdsTrialStats.config)
for k in query:
yield model_to_dict(k)
for trial in query:
if include_intermediates:
data = model_to_dict(trial)
# exclude 'trial' from intermediates as it is already available in data
data['intermediates'] = [
{k: v for k, v in model_to_dict(t).items() if k != 'trial'} for t in trial.intermediates
]
yield data
else:
yield model_to_dict(trial)
5 changes: 5 additions & 0 deletions src/sdk/pynni/nni/nas/benchmarks/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import functools
import json


json_dumps = functools.partial(json.dumps, sort_keys=True)