Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Retiarii] Visualization #3878

Merged
merged 15 commits into from
Jul 12, 2021
2 changes: 2 additions & 0 deletions docs/en_US/NAS/QuickStart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ Visualize the Experiment

Users can visualize their experiment in the same way as visualizing a normal hyper-parameter tuning experiment. For example, open ``localhost:8081`` in your browser, where 8081 is the port that you set in ``exp.run``. Please refer to `here <../../Tutorial/WebUI.rst>`__ for details.

We support visualizing models with 3rd-party visualization engines (like `Netron <https://netron.app/>`__). This can be used by clicking ``Visualization`` in the detail panel of each trial. Note that the current visualization is based on `onnx <https://onnx.ai/>`__. Built-in evaluators (e.g., Classification) will automatically export the model into a file; for your own evaluator, you need to save your model into ``$NNI_OUTPUT_DIR/model.onnx`` to make this work.

Export Top Models
-----------------

Expand Down
2 changes: 2 additions & 0 deletions docs/en_US/NAS/WriteTrainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ The simplest way to customize a new evaluator is with functional APIs, which is

.. note:: Due to a limitation of our current implementation, the ``fit`` function should be put in another python file instead of putting it in the main file. This limitation will be fixed in a future release.

.. note:: When using customized evaluators, if you want to visualize models, you need to export your model and save it into ``$NNI_OUTPUT_DIR/model.onnx`` in your evaluator.

With PyTorch-Lightning
----------------------

Expand Down
11 changes: 6 additions & 5 deletions nni/experiment/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from subprocess import Popen
import sys
import time
from typing import Optional, Tuple
from typing import Optional, Tuple, List, Any

import colorama

Expand Down Expand Up @@ -43,7 +43,7 @@ def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bo
_check_rest_server(port)
platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform
_save_experiment_information(exp_id, port, start_time, platform,
config.experiment_name, proc.pid, str(config.experiment_working_directory))
config.experiment_name, proc.pid, str(config.experiment_working_directory), [])
_logger.info('Setting up...')
rest.post(port, '/experiment', config.json())
return proc
Expand Down Expand Up @@ -78,7 +78,7 @@ def start_experiment_retiarii(exp_id: str, config: ExperimentConfig, port: int,
_check_rest_server(port)
platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform
_save_experiment_information(exp_id, port, start_time, platform,
config.experiment_name, proc.pid, config.experiment_working_directory)
config.experiment_name, proc.pid, config.experiment_working_directory, ['retiarii'])
_logger.info('Setting up...')
rest.post(port, '/experiment', config.json())
return proc, pipe
Expand Down Expand Up @@ -156,9 +156,10 @@ def _check_rest_server(port: int, retry: int = 3) -> None:
rest.get(port, '/check-status')


def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str, name: str, pid: int, logDir: str) -> None:
def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str,
name: str, pid: int, logDir: str, tag: List[Any]) -> None:
experiments_config = Experiments()
experiments_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir)
experiments_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir, tag=tag)


def get_stopped_experiment_config(exp_id: str, mode: str) -> None:
Expand Down
56 changes: 44 additions & 12 deletions nni/retiarii/evaluator/pytorch/lightning.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import warnings
from typing import Dict, Union, Optional, List
from pathlib import Path
from typing import Dict, NoReturn, Union, Optional, List, Type

import pytorch_lightning as pl
import torch.nn as nn
Expand All @@ -18,7 +20,13 @@


class LightningModule(pl.LightningModule):
def set_model(self, model):
"""
Basic wrapper of generated model.

Lightning modules used in NNI should inherit this class.
"""

def set_model(self, model: Union[Type[nn.Module], nn.Module]) -> NoReturn:
if isinstance(model, type):
self.model = model()
else:
Expand Down Expand Up @@ -112,13 +120,23 @@ class _SupervisedLearningModule(LightningModule):
def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
optimizer: optim.Optimizer = optim.Adam,
export_onnx: Union[Path, str, bool, None] = None):
super().__init__()
self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
self.criterion = criterion()
self.optimizer = optimizer
self.metrics = nn.ModuleDict({name: cls() for name, cls in metrics.items()})

if export_onnx is None or export_onnx is True:
self.export_onnx = Path(os.environ.get('NNI_OUTPUT_DIR', '.')) / 'model.onnx'
self.export_onnx.parent.mkdir(exist_ok=True)
elif export_onnx:
self.export_onnx = Path(export_onnx)
else:
self.export_onnx = None
self._already_exported = False

def forward(self, x):
y_hat = self.model(x)
return y_hat
Expand All @@ -135,6 +153,11 @@ def training_step(self, batch, batch_idx):
def validation_step(self, batch, batch_idx):
    """
    Run one validation batch, log its loss and metrics, and export the model to ONNX once.

    Parameters
    ----------
    batch : tuple
        A pair ``(x, y)`` of model input and target.
    batch_idx : int
        Index of the batch; not used here.
    """
    x, y = batch
    y_hat = self(x)

    # ``export_onnx`` is None when exporting is disabled; in that case ``to_onnx``
    # must not be called at all. Otherwise export a single time, using the first
    # validation batch as the sample input.
    if self.export_onnx is not None and not self._already_exported:
        self.to_onnx(self.export_onnx, x, export_params=True)
        self._already_exported = True

    self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
    for name, metric in self.metrics.items():
        self.log('val_' + name, metric(y_hat, y), prog_bar=True)
Expand All @@ -152,9 +175,8 @@ def configure_optimizers(self):
def on_validation_epoch_end(self):
nni.report_intermediate_result(self._get_validation_metrics())

def teardown(self, stage):
if stage == 'fit':
nni.report_final_result(self._get_validation_metrics())
def on_fit_end(self):
nni.report_final_result(self._get_validation_metrics())

def _get_validation_metrics(self):
if len(self.metrics) == 1:
Expand All @@ -175,9 +197,11 @@ class _ClassificationModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
optimizer: optim.Optimizer = optim.Adam,
export_onnx: bool = True):
super().__init__(criterion, {'acc': _AccuracyWithLogits},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
export_onnx=export_onnx)


class Classification(Lightning):
Expand All @@ -200,6 +224,8 @@ class Classification(Lightning):
val_dataloaders : DataLoader or List of DataLoader
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
export_onnx : bool
If true, model will be exported to ``model.onnx`` before training starts. default: true
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
Expand All @@ -211,9 +237,10 @@ def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
optimizer: optim.Optimizer = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
export_onnx: bool = True,
**trainer_kwargs):
module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer)
weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
super().__init__(module, Trainer(**trainer_kwargs),
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)

Expand All @@ -223,9 +250,11 @@ class _RegressionModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
optimizer: optim.Optimizer = optim.Adam,
export_onnx: bool = True):
super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
export_onnx=export_onnx)


class Regression(Lightning):
Expand All @@ -248,6 +277,8 @@ class Regression(Lightning):
val_dataloaders : DataLoader or List of DataLoader
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
export_onnx : bool
If true, model will be exported to ``model.onnx`` before training starts. default: true
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
Expand All @@ -259,8 +290,9 @@ def __init__(self, criterion: nn.Module = nn.MSELoss,
optimizer: optim.Optimizer = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
export_onnx: bool = True,
**trainer_kwargs):
module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer)
weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
super().__init__(module, Trainer(**trainer_kwargs),
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
1 change: 1 addition & 0 deletions test/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ _generated_model
data
generated
lightning_logs
model.onnx
2 changes: 1 addition & 1 deletion test/ut/tools/nnictl/mock/restful_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def mock_get_latest_metric_data():

def mock_get_trial_log():
responses.add(
responses.DELETE, 'http://localhost:8080/api/v1/nni/trial-log/:id/:type',
responses.DELETE, 'http://localhost:8080/api/v1/nni/trial-file/:id/:filename',
json={"status":"RUNNING","errors":[]},
status=200,
content_type='application/json',
Expand Down
4 changes: 2 additions & 2 deletions ts/nni_manager/common/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
'use strict';

import { MetricDataRecord, MetricType, TrialJobInfo } from './datastore';
import { TrialJobStatus, LogType } from './trainingService';
import { TrialJobStatus } from './trainingService';
import { ExperimentConfig } from './experimentConfig';

type ProfileUpdateType = 'TRIAL_CONCURRENCY' | 'MAX_EXEC_DURATION' | 'SEARCH_SPACE' | 'MAX_TRIAL_NUM';
Expand Down Expand Up @@ -59,7 +59,7 @@ abstract class Manager {
public abstract getMetricDataByRange(minSeqId: number, maxSeqId: number): Promise<MetricDataRecord[]>;
public abstract getLatestMetricData(): Promise<MetricDataRecord[]>;

public abstract getTrialLog(trialJobId: string, logType: LogType): Promise<string>;
public abstract getTrialFile(trialJobId: string, fileName: string): Promise<Buffer | string>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The behavior would be more predictable if this were split into separate "get text" and "get binary" methods.
If you have time, not mandatory...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm having a hard time thinking of another meaningful name.
Let's do it in future when we actually need this one.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a breaking change.


public abstract getTrialJobStatistics(): Promise<TrialJobStatistics[]>;
public abstract getStatus(): NNIManagerStatus;
Expand Down
6 changes: 2 additions & 4 deletions ts/nni_manager/common/trainingService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
*/
type TrialJobStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED' | 'SYS_CANCELED' | 'EARLY_STOPPED';

type LogType = 'TRIAL_LOG' | 'TRIAL_STDOUT' | 'TRIAL_ERROR';

interface TrainingServiceMetadata {
readonly key: string;
readonly value: string;
Expand Down Expand Up @@ -81,7 +79,7 @@ abstract class TrainingService {
public abstract submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail>;
public abstract updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail>;
public abstract cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean): Promise<void>;
public abstract getTrialLog(trialJobId: string, logType: LogType): Promise<string>;
public abstract getTrialFile(trialJobId: string, fileName: string): Promise<Buffer | string>;
public abstract setClusterMetadata(key: string, value: string): Promise<void>;
public abstract getClusterMetadata(key: string): Promise<string>;
public abstract getTrialOutputLocalPath(trialJobId: string): Promise<string>;
Expand All @@ -103,5 +101,5 @@ class NNIManagerIpConfig {
export {
TrainingService, TrainingServiceError, TrialJobStatus, TrialJobApplicationForm,
TrainingServiceMetadata, TrialJobDetail, TrialJobMetric, HyperParameters,
NNIManagerIpConfig, LogType
NNIManagerIpConfig
};
6 changes: 3 additions & 3 deletions ts/nni_manager/core/nnimanager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import { ExperimentConfig, toSeconds, toCudaVisibleDevices } from '../common/exp
import { ExperimentManager } from '../common/experimentManager';
import { TensorboardManager } from '../common/tensorboardManager';
import {
TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus, LogType
TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus
} from '../common/trainingService';
import { delay, getCheckpointDir, getExperimentRootDir, getLogDir, getMsgDispatcherCommand, mkDirP, getTunerProc, getLogLevel, isAlive, killPid } from '../common/utils';
import {
Expand Down Expand Up @@ -403,8 +403,8 @@ class NNIManager implements Manager {
// FIXME: unit test
}

public async getTrialLog(trialJobId: string, logType: LogType): Promise<string> {
return this.trainingService.getTrialLog(trialJobId, logType);
public async getTrialFile(trialJobId: string, fileName: string): Promise<Buffer | string> {
return this.trainingService.getTrialFile(trialJobId, fileName);
}

public getExperimentProfile(): Promise<ExperimentProfile> {
Expand Down
4 changes: 2 additions & 2 deletions ts/nni_manager/core/test/mockedTrainingService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { Deferred } from 'ts-deferred';
import { Provider } from 'typescript-ioc';

import { MethodNotImplementedError } from '../../common/errors';
import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType } from '../../common/trainingService';
import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric } from '../../common/trainingService';

const testTrainingServiceProvider: Provider = {
get: () => { return new MockedTrainingService(); }
Expand Down Expand Up @@ -63,7 +63,7 @@ class MockedTrainingService extends TrainingService {
return deferred.promise;
}

public getTrialLog(trialJobId: string, logType: LogType): Promise<string> {
public getTrialFile(trialJobId: string, fileName: string): Promise<string> {
throw new MethodNotImplementedError();
}

Expand Down
2 changes: 2 additions & 0 deletions ts/nni_manager/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"child-process-promise": "^2.2.1",
"express": "^4.17.1",
"express-joi-validator": "^2.0.1",
"http-proxy": "^1.18.1",
"ignore": "^5.1.8",
"js-base64": "^3.6.1",
"kubernetes-client": "^6.12.1",
Expand All @@ -37,6 +38,7 @@
"@types/chai-as-promised": "^7.1.0",
"@types/express": "^4.17.2",
"@types/glob": "^7.1.3",
"@types/http-proxy": "^1.17.7",
"@types/js-base64": "^3.3.1",
"@types/js-yaml": "^4.0.1",
"@types/lockfile": "^1.0.0",
Expand Down
11 changes: 11 additions & 0 deletions ts/nni_manager/rest_server/nniRestServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import * as bodyParser from 'body-parser';
import * as express from 'express';
import * as httpProxy from 'http-proxy';
import * as path from 'path';
import * as component from '../common/component';
import { RestServer } from '../common/restServer'
Expand All @@ -21,6 +22,7 @@ import { getAPIRootUrl } from '../common/experimentStartupInfo';
@component.Singleton
export class NNIRestServer extends RestServer {
private readonly LOGS_ROOT_URL: string = '/logs';
protected netronProxy: any = null;
protected API_ROOT_URL: string = '/api/v1/nni';

/**
Expand All @@ -29,6 +31,7 @@ export class NNIRestServer extends RestServer {
constructor() {
super();
this.API_ROOT_URL = getAPIRootUrl();
this.netronProxy = httpProxy.createProxyServer();
}

/**
Expand All @@ -39,6 +42,14 @@ export class NNIRestServer extends RestServer {
this.app.use(bodyParser.json({limit: '50mb'}));
this.app.use(this.API_ROOT_URL, createRestHandler(this));
this.app.use(this.LOGS_ROOT_URL, express.static(getLogDir()));
this.app.all('/netron/*', (req: express.Request, res: express.Response) => {
delete req.headers.host;
req.url = req.url.replace('/netron', '/');
this.netronProxy.web(req, res, {
changeOrigin: true,
target: 'https://netron.app'
});
});
this.app.get('*', (req: express.Request, res: express.Response) => {
res.sendFile(path.resolve('static/index.html'));
});
Expand Down
Loading