Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Commit

Permalink
webgl: support iOS devices (#132)
Browse files Browse the repository at this point in the history
* webgl: enable check in debug mode

* webgl: allow half_float textures

* update condition for forceReadUint8

* fix isNaN() in ios12

* test: update whitelist for iOS

* use high precision sampler2D by default

* webgl: stop using window.WebGLRenderingContext

* test: update karma config for BrowserStack

* test: optimize randomArray in conv test
  • Loading branch information
fs-eire authored Apr 23, 2019
1 parent 2a1354a commit c1c8f90
Show file tree
Hide file tree
Showing 16 changed files with 355 additions and 251 deletions.
118 changes: 91 additions & 27 deletions karma.conf.js
Original file line number Diff line number Diff line change
@@ -1,8 +1,36 @@
module.exports = function(config) {
const bundleMode = require('minimist')(process.argv)['bundle-mode'] || 'dev'; // 'dev'|'perf'|undefined;
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

const bundleMode = require('minimist')(process.argv)['bundle-mode'] || 'dev'; // 'dev'|'perf'|undefined;
const mainFile = bundleMode === 'perf' ? 'test/onnx.perf.js' : 'test/onnx.dev.js';

// it's a known issue that Safari does not work with "localhost" in BrowserStack:
// https://www.browserstack.com/question/663
//
// we need to read machine IP address to replace "localhost":
// https://stackoverflow.com/a/8440736
//
function getMachineIpAddress() {
var os = require('os');
var ifaces = os.networkInterfaces();

for (const ifname in ifaces) {
for (const iface of ifaces[ifname]) {
if ('IPv4' !== iface.family || iface.internal !== false) {
// skip over internal (i.e. 127.0.0.1) and non-ipv4 addresses
continue;
}

// returns the first available IP address
return iface.address;
}
}

const mainFile = bundleMode === 'perf' ? 'test/onnx.perf.js' : 'test/onnx.dev.js';
// if no available IP address, fallback to "localhost".
return 'localhost';
}

module.exports = function(config) {
config.set({
// global config of your BrowserStack account
browserStack: {
Expand All @@ -27,64 +55,100 @@ module.exports = function(config) {
},
client: {captureConsole: true, mocha: {expose: ['body'], timeout: 60000}},
preprocessors: {mainFile: ['sourcemap']},
reporters: ['mocha'],
browsers: [
'ChromeTest',
'ChromeDebug',
'Edge',
'Firefox',
'Electron',
'Safari',
'BS_WIN_Chrome',
'BS_WIN_Edge',
'BS_WIN_Firefox',
'BS_MAC_Chrome',
'BS_MAC_Safari',
],
reporters: ['mocha', 'BrowserStack'],
browsers: [],
captureTimeout: 120000,
reportSlowerThan: 100,
browserDisconnectTimeout: 600000,
browserNoActivityTimeout: 300000,
browserDisconnectTolerance: 0,
browserSocketTimeout: 60000,
hostname: getMachineIpAddress(),
customLaunchers: {
ChromeTest: {base: 'Chrome', flags: ['--window-size=1,1']},
ChromeDebug: {debug: true, base: 'Chrome', flags: ['--remote-debugging-port=9333']},
BS_WIN_Chrome: {
//
// ==== BrowserStack browsers ====
//

// Windows
//
BS_WIN_10_Chrome_73: {
base: 'BrowserStack',
browser: 'Chrome',
browser_version: '71.0',
browser_version: '73.0',
os: 'Windows',
os_version: '10',
},
BS_WIN_Edge: {
BS_WIN_10_Edge_18: {
base: 'BrowserStack',
os: 'Windows',
os_version: '10',
browser: 'Edge',
browser_version: '18.0',
},
BS_WIN_Firefox: {
BS_WIN_10_Firefox_66: {
base: 'BrowserStack',
os: 'Windows',
os_version: '10',
browser: 'Firefox',
browser_version: '63.0',
browser_version: '66.0',
},
BS_MAC_Chrome: {
BS_WIN_7_Chrome_63: {
base: 'BrowserStack',
browser: 'Chrome',
browser_version: '71.0',
browser_version: '63.0',
os: 'Windows',
os_version: '7',
},

// macOS
//
BS_MAC_10_14_Safari_12: {
base: 'BrowserStack',
os: 'OS X',
os_version: 'High Sierra',
os_version: 'Mojave',
browser: 'Safari',
browser_version: '12.0',
},
BS_MAC_10_14_Chrome_73: {
base: 'BrowserStack',
os: 'OS X',
os_version: 'Mojave',
browser: 'Chrome',
browser_version: '73.0',
},
BS_MAC_Safari: {
BS_MAC_10_13_Safari_11_1: {
base: 'BrowserStack',
os: 'OS X',
os_version: 'High Sierra',
browser: 'Safari',
browser_version: '11.1',
}
},

// iPhone
//
BS_IOS_12_1_iPhoneXS: {
base: 'BrowserStack',
device: 'iPhone XS',
real_mobile: true,
os: 'ios',
os_version: '12.1',
},
BS_IOS_11_iPhoneX: {
base: 'BrowserStack',
device: 'iPhone X',
real_mobile: true,
os: 'ios',
os_version: '11',
},
BS_IOS_10_3_iPhone7: {
base: 'BrowserStack',
device: 'iPhone 7',
real_mobile: true,
os: 'ios',
os_version: '10.3',
},
}
});
};
8 changes: 3 additions & 5 deletions lib/backends/backend-webgl.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import * as platform from 'platform';

import {Backend as BackendInterface} from '../api/onnx';
import {Backend, SessionHandler} from '../backend';
import {Logger} from '../instrument';
import {Session} from '../session';

import {WebGLSessionHandler} from './webgl/session-handler';
import {WebGLContext} from './webgl/webgl-context';
import {WebGLContextFactory} from './webgl/webgl-context-factory';
import {createWebGLContext} from './webgl/webgl-context-factory';

type WebGLOptions = BackendInterface.WebGLOptions;

Expand All @@ -27,10 +25,10 @@ export class WebGLBackend implements Backend, WebGLOptions {

initialize(): boolean {
try {
if (platform.name === 'Safari') {
this.glContext = createWebGLContext(this.contextId);
if (!this.glContext.floatDownloadEnabled) {
this.forceUint8Reads = true;
}
this.glContext = WebGLContextFactory.create(this.contextId);
Logger.verbose('WebGLBackend', `Created WebGLContext: ${typeof this.glContext}`);
return true;
} catch (e) {
Expand Down
1 change: 1 addition & 0 deletions lib/backends/webgl/glsl-preprocessor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ export class GlslPreprocessor {
return `
precision highp float;
precision highp int;
precision highp sampler2D;
varying vec2 TexCoords;
${script}
Expand Down
16 changes: 6 additions & 10 deletions lib/backends/webgl/inference-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {WebGLUint8Encode} from './ops/uint8-encode';
import {ProgramManager} from './program-manager';
import {WebGLSessionHandler} from './session-handler';
import {TextureData, TextureLayout} from './texture-data';
import {Encoder} from './texture-data-encoder';
import {TextureHelper} from './texture-helper';
import {WidthHeightPrefs} from './texture-layout-strategy';
import {getPackedShape} from './utils';
Expand Down Expand Up @@ -43,7 +44,8 @@ export class WebGLInferenceHandler implements InferenceHandler {
if (!layout) {
layout = this.createBasicTextureLayout(tensor.dims.slice());
}
td = this.createTextureDataFromLayout(layout, tensor.type, tensor.numberData);
// graph inputs or initializers
td = this.createTextureDataFromLayout(layout, tensor.type, tensor.numberData, Encoder.Usage.UploadOnly);
this.setTextureData(tensor, td);
} else {
Logger.verbose('InferenceHandler', `Retrieving TextureData from cache: [${tensor.dims}]`);
Expand Down Expand Up @@ -95,16 +97,10 @@ export class WebGLInferenceHandler implements InferenceHandler {
this.tensorToTexture = new Map();
this.textureToTensor = new Map();
}
createTextureData(
dataType: Tensor.DataType, shape: ReadonlyArray<number>, strides?: ReadonlyArray<number>,
data?: Tensor.NumberType, channels?: number, width?: number, height?: number): TextureData {
Logger.verbose('InferenceHandler', `Creating TextureData: shape:[${shape}], channels:${channels ? channels : 1}`);
const td = this.textureHelper.createTexture(dataType, shape, strides, data, channels, width, height);
return td;
}
createTextureDataFromLayout(layout: TextureLayout, dataType: Tensor.DataType, data?: Tensor.NumberType): TextureData {
createTextureDataFromLayout(
layout: TextureLayout, dataType: Tensor.DataType, data?: Tensor.NumberType, usage?: Encoder.Usage): TextureData {
Logger.verbose('InferenceHandler', `Creating TextureData: layout:[${JSON.stringify(layout)}]`);
const td = this.textureHelper.createTextureFromLayout(dataType, layout, data);
const td = this.textureHelper.createTextureFromLayout(dataType, layout, data, usage);
return td;
}
createBasicTextureLayout(
Expand Down
4 changes: 3 additions & 1 deletion lib/backends/webgl/ops/conv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo} from '../program-info';
import {Artifact, RunData} from '../program-manager';
import {TextureLayout} from '../texture-data';
import {Encoder} from '../texture-data-encoder';
import {WebGLContext} from '../webgl-context';

export class WebGLConv extends Conv {
Expand Down Expand Up @@ -56,7 +57,8 @@ export class WebGLConv extends Conv {
Logger.verbose('Conv', 'Did not find the adjustedKernel texture in the cache. Creating rew.');
const newKernelData =
WebGLConv.prepKernelForDotProduct(k.dims.slice(), this.group, 4, k.floatData as Float32Array);
kTD = inferenceHandler.createTextureDataFromLayout(programInfos[1].inputLayouts[1], k.type, newKernelData);
kTD = inferenceHandler.createTextureDataFromLayout(
programInfos[1].inputLayouts[1], k.type, newKernelData, Encoder.Usage.UploadOnly);
inferenceHandler.setTextureData(k, kTD);
}
const runtDataIm2Col = {
Expand Down
6 changes: 4 additions & 2 deletions lib/backends/webgl/ops/uint8-encode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export class WebGLUint8Encode {
strides: ShapeUtil.computeStrides(outputShape),
unpackedShape: outputShape
};
// TODO: remove this special script. Use graph transformer instead.
/**
* https://github.com/tensorflow/tfjs-core/blob/master/src/kernels/webgl/encode_float_gpu.ts
*/
Expand All @@ -27,7 +28,7 @@ export class WebGLUint8Encode {
uniform sampler2D X;
bool isNaN(float val) {
return (val < 0.0 || 0.0 < val || val == 0.0) ? false : true;
return (val < 1.0 || 0.0 < val || val == 0.0) ? false : true;
}
highp vec4 encodeAsUint8(highp float v) {
Expand Down Expand Up @@ -78,8 +79,9 @@ export class WebGLUint8Encode {
};
const artifact = inferenceHandler.programManager.build(programInfo);

const encoder = inferenceHandler.backend.glContext.getEncoder('byte', 4);
const texture =
inferenceHandler.backend.glContext.allocateTexture(outputLayout.width, outputLayout.height, 'byte', 4);
inferenceHandler.backend.glContext.allocateTexture(outputLayout.width, outputLayout.height, encoder);
const outputTextureData: TextureData = {...outputLayout, dataType: 'uint8', texture};
const runData = {inputTextureDatas: [input], outputTextureData, uniformData: {}};

Expand Down
4 changes: 2 additions & 2 deletions lib/backends/webgl/program-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,9 @@ export class ProgramManager {
if (!this.vertexShader) {
Logger.verbose('ProrgramManager', 'Compiling and caching Vertex shader for the first time');
this.vertexShader =
this.glContext.compileShader(this.getDefaultVertexShaderSource(), WebGLRenderingContext.VERTEX_SHADER);
this.glContext.compileShader(this.getDefaultVertexShaderSource(), this.glContext.gl.VERTEX_SHADER);
}
const fragShader = this.glContext.compileShader(fragShaderScript, WebGLRenderingContext.FRAGMENT_SHADER);
const fragShader = this.glContext.compileShader(fragShaderScript, this.glContext.gl.FRAGMENT_SHADER);
const program = this.glContext.createProgram(this.vertexShader, fragShader);
this.glContext.deleteShader(fragShader);
return program;
Expand Down
Loading

0 comments on commit c1c8f90

Please sign in to comment.