Skip to content

Commit

Permalink
Refactor TensorBoard prompt and import tracking and add tests (#15073)
Browse files Browse the repository at this point in the history
  • Loading branch information
joyceerhl authored Jan 7, 2021
1 parent 48abee1 commit 3fd3b9e
Show file tree
Hide file tree
Showing 9 changed files with 206 additions and 132 deletions.
7 changes: 4 additions & 3 deletions src/client/tensorBoard/serviceRegistry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import { TensorBoardFileWatcher } from './tensorBoardFileWatcher';
import { TensorBoardImportTracker } from './tensorBoardImportTracker';
import { TensorBoardPrompt } from './tensorBoardPrompt';
import { TensorBoardSessionProvider } from './tensorBoardSessionProvider';
import { ITensorBoardImportTracker } from './types';

export function registerTypes(serviceManager: IServiceManager): void {
serviceManager.addSingleton<IExtensionSingleActivationService>(
Expand All @@ -19,8 +18,10 @@ export function registerTypes(serviceManager: IServiceManager): void {
serviceManager.addSingleton<TensorBoardFileWatcher>(TensorBoardFileWatcher, TensorBoardFileWatcher);
serviceManager.addBinding(TensorBoardFileWatcher, IExtensionSingleActivationService);
serviceManager.addSingleton<TensorBoardPrompt>(TensorBoardPrompt, TensorBoardPrompt);
serviceManager.addSingleton<ITensorBoardImportTracker>(ITensorBoardImportTracker, TensorBoardImportTracker);
serviceManager.addBinding(ITensorBoardImportTracker, IExtensionSingleActivationService);
serviceManager.addSingleton<IExtensionSingleActivationService>(
IExtensionSingleActivationService,
TensorBoardImportTracker,
);
serviceManager.addSingleton<IExtensionSingleActivationService>(
IExtensionSingleActivationService,
TensorBoardCodeLensProvider,
Expand Down
22 changes: 5 additions & 17 deletions src/client/tensorBoard/tensorBoardImportTracker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,24 @@

import { inject, injectable } from 'inversify';
import * as path from 'path';
import { Event, EventEmitter, TextEditor } from 'vscode';
import { TextEditor } from 'vscode';
import { IExtensionSingleActivationService } from '../activation/types';
import { IDocumentManager } from '../common/application/types';
import { isTestExecution } from '../common/constants';
import { IDisposableRegistry } from '../common/types';
import { getDocumentLines } from '../telemetry/importTracker';
import { containsTensorBoardImport } from './helpers';
import { ITensorBoardImportTracker } from './types';
import { TensorBoardPrompt } from './tensorBoardPrompt';

const testExecution = isTestExecution();
@injectable()
export class TensorBoardImportTracker implements ITensorBoardImportTracker, IExtensionSingleActivationService {
private pendingChecks = new Map<string, NodeJS.Timer | number>();

private _onDidImportTensorBoard = new EventEmitter<void>();

export class TensorBoardImportTracker implements IExtensionSingleActivationService {
constructor(
@inject(IDocumentManager) private documentManager: IDocumentManager,
@inject(IDisposableRegistry) private disposables: IDisposableRegistry,
@inject(TensorBoardPrompt) private prompt: TensorBoardPrompt,
) {}

// Fires when the active text editor contains a tensorboard import.
public get onDidImportTensorBoard(): Event<void> {
return this._onDidImportTensorBoard.event;
}

public dispose(): void {
this.pendingChecks.clear();
}

public async activate(): Promise<void> {
if (testExecution) {
await this.activateInternal();
Expand Down Expand Up @@ -63,7 +51,7 @@ export class TensorBoardImportTracker implements ITensorBoardImportTracker, IExt
) {
const lines = getDocumentLines(document);
if (containsTensorBoardImport(lines)) {
this._onDidImportTensorBoard.fire();
this.prompt.showNativeTensorBoardPrompt().ignoreErrors();
}
}
}
Expand Down
12 changes: 4 additions & 8 deletions src/client/tensorBoard/tensorBoardPrompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ import { inject, injectable } from 'inversify';
import { IApplicationShell, ICommandManager } from '../common/application/types';
import { Commands } from '../common/constants';
import { NativeTensorBoard } from '../common/experiments/groups';
import { IDisposableRegistry, IExperimentService, IPersistentState, IPersistentStateFactory } from '../common/types';
import { IExperimentService, IPersistentState, IPersistentStateFactory } from '../common/types';
import { Common, TensorBoard } from '../common/utils/localize';
import { ITensorBoardImportTracker } from './types';

enum TensorBoardPromptStateKeys {
ShowNativeTensorBoardPrompt = 'showNativeTensorBoardPrompt',
Expand All @@ -17,7 +16,7 @@ enum TensorBoardPromptStateKeys {
export class TensorBoardPrompt {
private state: IPersistentState<boolean>;

private enabled: Promise<boolean>;
private enabled: boolean;

private inExperiment: Promise<boolean>;

Expand All @@ -28,8 +27,6 @@ export class TensorBoardPrompt {
constructor(
@inject(IApplicationShell) private applicationShell: IApplicationShell,
@inject(ICommandManager) private commandManager: ICommandManager,
@inject(ITensorBoardImportTracker) private importTracker: ITensorBoardImportTracker,
@inject(IDisposableRegistry) private disposableRegistry: IDisposableRegistry,
@inject(IPersistentStateFactory) private persistentStateFactory: IPersistentStateFactory,
@inject(IExperimentService) private experimentService: IExperimentService,
) {
Expand All @@ -39,13 +36,12 @@ export class TensorBoardPrompt {
);
this.enabled = this.isPromptEnabled();
this.inExperiment = this.isInExperiment();
this.importTracker.onDidImportTensorBoard(this.showNativeTensorBoardPrompt, this, this.disposableRegistry);
}

public async showNativeTensorBoardPrompt(): Promise<void> {
if (
(await this.inExperiment) &&
(await this.enabled) &&
this.enabled &&
this.enabledInCurrentSession &&
!this.waitingForUserSelection
) {
Expand Down Expand Up @@ -73,7 +69,7 @@ export class TensorBoardPrompt {
}
}

private async isPromptEnabled(): Promise<boolean> {
private isPromptEnabled(): boolean {
return this.state.value;
}

Expand Down
110 changes: 52 additions & 58 deletions src/client/tensorBoard/tensorBoardSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import { createPromiseFromCancellation } from '../common/cancellation';
import { traceError, traceInfo } from '../common/logger';
import { tensorboardLauncher } from '../common/process/internal/scripts';
import { IProcessServiceFactory, ObservableExecutionResult } from '../common/process/types';
import { IInstaller, InstallerResponse, Product } from '../common/types';
import { IDisposableRegistry, IInstaller, InstallerResponse, Product } from '../common/types';
import { createDeferred, sleep } from '../common/utils/async';
import { TensorBoard } from '../common/utils/localize';
import { IInterpreterService } from '../interpreter/contracts';
Expand Down Expand Up @@ -47,6 +47,7 @@ export class TensorBoardSession {
private readonly workspaceService: IWorkspaceService,
private readonly processServiceFactory: IProcessServiceFactory,
private readonly commandManager: ICommandManager,
private readonly disposables: IDisposableRegistry,
) {}

public async initialize(): Promise<void> {
Expand Down Expand Up @@ -241,67 +242,60 @@ export class TensorBoardSession {
const webviewPanel = window.createWebviewPanel('tensorBoardSession', 'TensorBoard', ViewColumn.Two, {
enableScripts: true,
});
webviewPanel.webview.html = `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta http-equiv="Content-Security-Policy" content="default-src 'unsafe-inline'; frame-src ${this.url} http: https:;">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>TensorBoard</title>
</head>
<body>
<script type="text/javascript">
function resizeFrame() {
var f = window.document.getElementById('vscode-tensorboard-iframe');
if (f) {
f.style.height = window.innerHeight / 0.7 + "px";
f.style.width = window.innerWidth / 0.7 + "px";
}
}
resizeFrame();
window.addEventListener('resize', resizeFrame);
</script>
<iframe
id="vscode-tensorboard-iframe"
class="responsive-iframe"
sandbox="allow-scripts allow-forms allow-same-origin allow-pointer-lock"
src="${this.url}"
frameborder="0"
border="0"
allowfullscreen
></iframe>
<style>
.responsive-iframe {
transform: scale(0.7);
transform-origin: 0 0;
position: absolute;
top: 0;
left: 0;
overflow: hidden;
display: block;
}
</style>
</body>
</html>`;
this.webviewPanel = webviewPanel;
webviewPanel.onDidDispose(() => {
this.webviewPanel = undefined;
// Kill the running TensorBoard session
this.process?.kill();
this.process = undefined;
});
webviewPanel.onDidChangeViewState(() => {
if (webviewPanel.visible) {
this.update();
}
}, null);
this.disposables.push(
webviewPanel.onDidDispose(() => {
this.webviewPanel = undefined;
// Kill the running TensorBoard session
this.process?.kill();
this.process = undefined;
}),
);
return webviewPanel;
}

private update() {
if (this.webviewPanel) {
this.webviewPanel.webview.html = `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta http-equiv="Content-Security-Policy" content="default-src 'unsafe-inline'; frame-src ${this.url};">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>TensorBoard</title>
</head>
<body>
<script type="text/javascript">
function resizeFrame() {
var f = window.document.getElementById('vscode-tensorboard-iframe');
if (f) {
f.style.height = window.innerHeight / 0.7 + "px";
f.style.width = window.innerWidth / 0.7 + "px";
}
}
window.addEventListener('resize', resizeFrame);
</script>
<iframe
id="vscode-tensorboard-iframe"
class="responsive-iframe"
sandbox="allow-scripts allow-forms allow-same-origin allow-pointer-lock"
src="${this.url}"
frameborder="0"
border="0"
allowfullscreen
></iframe>
<style>
.responsive-iframe {
transform: scale(0.7);
transform-origin: 0 0;
position: absolute;
top: 0;
left: 0;
overflow: hidden;
display: block;
}
</style>
</body>
</html>`;
}
}

private autopopulateLogDirectoryPath(): string | undefined {
if (this.workspaceService.rootPath) {
return this.workspaceService.rootPath;
Expand Down
1 change: 1 addition & 0 deletions src/client/tensorBoard/tensorBoardSessionProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ export class TensorBoardSessionProvider implements IExtensionSingleActivationSer
this.workspaceService,
this.processServiceFactory,
this.commandManager,
this.disposables,
);
await newSession.initialize();
} catch (e) {
Expand Down
24 changes: 24 additions & 0 deletions src/test/tensorBoard/helpers.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import * as TypeMoq from 'typemoq';
import { IApplicationShell, ICommandManager } from '../../client/common/application/types';
import { IExperimentService, IPersistentStateFactory } from '../../client/common/types';
import { TensorBoardPrompt } from '../../client/tensorBoard/tensorBoardPrompt';
import { MockState } from '../interpreters/mocks';

export function createTensorBoardPromptWithMocks(): TensorBoardPrompt {
const appShell = TypeMoq.Mock.ofType<IApplicationShell>();
const commandManager = TypeMoq.Mock.ofType<ICommandManager>();
const persistentStateFactory = TypeMoq.Mock.ofType<IPersistentStateFactory>();
const expService = TypeMoq.Mock.ofType<IExperimentService>();
const persistentState = new MockState(true);
persistentStateFactory
.setup((factory) => {
factory.createWorkspacePersistentState(TypeMoq.It.isAny(), TypeMoq.It.isAny());
})
.returns(() => persistentState);
return new TensorBoardPrompt(
appShell.object,
commandManager.object,
persistentStateFactory.object,
expService.object,
);
}
47 changes: 20 additions & 27 deletions src/test/tensorBoard/tensorBoardImportTracker.unit.test.ts
Original file line number Diff line number Diff line change
@@ -1,86 +1,79 @@
import { assert } from 'chai';
import * as sinon from 'sinon';
import { TensorBoardImportTracker } from '../../client/tensorBoard/tensorBoardImportTracker';
import { TensorBoardPrompt } from '../../client/tensorBoard/tensorBoardPrompt';
import { MockDocumentManager } from '../startPage/mockDocumentManager';
import { createTensorBoardPromptWithMocks } from './helpers';

suite('TensorBoard import tracker', () => {
let documentManager: MockDocumentManager;
let tensorBoardImportTracker: TensorBoardImportTracker;
let onDidImportTensorBoardListener: sinon.SinonExpectation;
let prompt: TensorBoardPrompt;
let showNativeTensorBoardPrompt: sinon.SinonSpy;

setup(() => {
documentManager = new MockDocumentManager();
tensorBoardImportTracker = new TensorBoardImportTracker(documentManager, []);
onDidImportTensorBoardListener = sinon.expectation.create('onDidImportTensorBoardListener');
tensorBoardImportTracker.onDidImportTensorBoard(onDidImportTensorBoardListener);
prompt = createTensorBoardPromptWithMocks();
showNativeTensorBoardPrompt = sinon.spy(prompt, 'showNativeTensorBoardPrompt');
tensorBoardImportTracker = new TensorBoardImportTracker(documentManager, [], prompt);
});

test('Simple tensorboard import in Python file', async () => {
const document = documentManager.addDocument('import tensorboard', 'foo.py');
await documentManager.showTextDocument(document);
await tensorBoardImportTracker.activate();
onDidImportTensorBoardListener.once().verify();
assert.ok(showNativeTensorBoardPrompt.calledOnce);
});
test('Simple tensorboard import in Python ipynb', async () => {
const document = documentManager.addDocument('import tensorboard', 'foo.ipynb');
await documentManager.showTextDocument(document);
await tensorBoardImportTracker.activate();
onDidImportTensorBoardListener.once().verify();
assert.ok(showNativeTensorBoardPrompt.calledOnce);
});
test('`from x.y.tensorboard import z` import', async () => {
const document = documentManager.addDocument('from torch.utils.tensorboard import SummaryWriter', 'foo.py');
await documentManager.showTextDocument(document);
await tensorBoardImportTracker.activate();
onDidImportTensorBoardListener.once().verify();
assert.ok(showNativeTensorBoardPrompt.calledOnce);
});
test('`from x.y import tensorboard` import', async () => {
const document = documentManager.addDocument('from torch.utils import tensorboard', 'foo.py');
await documentManager.showTextDocument(document);
await tensorBoardImportTracker.activate();
onDidImportTensorBoardListener.once().verify();
assert.ok(showNativeTensorBoardPrompt.calledOnce);
});
test('`import x, y` import', async () => {
const document = documentManager.addDocument('import tensorboard, tensorflow', 'foo.py');
await documentManager.showTextDocument(document);
await tensorBoardImportTracker.activate();
onDidImportTensorBoardListener.once().verify();
assert.ok(showNativeTensorBoardPrompt.calledOnce);
});
test('`import pkg as _` import', async () => {
const document = documentManager.addDocument('import tensorboard as tb', 'foo.py');
await documentManager.showTextDocument(document);
await tensorBoardImportTracker.activate();
onDidImportTensorBoardListener.once().verify();
assert.ok(showNativeTensorBoardPrompt.calledOnce);
});
test('Fire on changed text editor', async () => {
test('Show prompt on changed text editor', async () => {
await tensorBoardImportTracker.activate();
const document = documentManager.addDocument('import tensorboard as tb', 'foo.py');
await documentManager.showTextDocument(document);
onDidImportTensorBoardListener.once().verify();
assert.ok(showNativeTensorBoardPrompt.calledOnce);
});
test('Do not fire event if no tensorboard import', async () => {
test('Do not show prompt if no tensorboard import', async () => {
const document = documentManager.addDocument('import tensorflow as tf\nfrom torch.utils import foo', 'foo.py');
await documentManager.showTextDocument(document);
await tensorBoardImportTracker.activate();
onDidImportTensorBoardListener.never().verify();
assert.ok(showNativeTensorBoardPrompt.notCalled);
});
test('Do not fire event if language is not Python', async () => {
test('Do not show prompt if language is not Python', async () => {
const document = documentManager.addDocument(
'import tensorflow as tf\nfrom torch.utils import foo',
'foo.cpp',
'cpp',
);
await documentManager.showTextDocument(document);
await tensorBoardImportTracker.activate();
onDidImportTensorBoardListener.never().verify();
});
test('Ignore docstrings', async () => {
const document = documentManager.addDocument(
`"""
import tensorboard
"""`,
'foo.py',
);
await documentManager.showTextDocument(document);
await tensorBoardImportTracker.activate();
onDidImportTensorBoardListener.never().verify();
assert.ok(showNativeTensorBoardPrompt.notCalled);
});
});
Loading

0 comments on commit 3fd3b9e

Please sign in to comment.